Amazon's Deep Java Library enables Java developers to more easily create machine learning and deep learning models.
Deep Java Library (DJL) is an open-source library created by Amazon to develop machine learning (ML) and deep learning (DL) models natively in Java while simplifying the use of deep learning frameworks.
I recently used DJL to develop a footwear classification model and found the toolkit super intuitive and easy to use; it's obvious a lot of thought went into the design and how Java developers would use it. DJL APIs abstract commonly used functions to develop models and orchestrate infrastructure management. I found the high-level APIs used to train, test, and run inference allowed me to use my knowledge of Java and the ML lifecycle to develop a model in less than an hour with minimal code.
Footwear classification model
The footwear classification model is a multiclass computer vision (CV) model, trained using supervised learning, that classifies footwear into one of four class labels: boots, sandals, shoes, or slippers.
About the data
The most important part of developing an accurate ML model is to use data from a reputable source. The data source for the footwear classification model is the UTZappos50k dataset provided by The University of Texas at Austin, which is freely available for academic, non-commercial use. The shoe dataset consists of 50,025 labeled catalog images collected from Zappos.com.
Train the footwear classification model
Training is the process of producing an ML model by giving a learning algorithm training data to study. The term model refers to the artifact produced during the training process; the model contains patterns found in the training data and can be used to make a prediction (or inference). Before I started the training process, I set up my local environment for development. You will need JDK 8 (or later), IntelliJ, an ML engine for training (like Apache MXNet), an environment variable pointing to your engine's path, and the build dependencies for DJL.
DJL stays true to Java's motto, "write once, run anywhere (WORA)", by being engine and deep learning framework-agnostic. Developers can write code once that runs on any engine. DJL currently provides an implementation for Apache MXNet, an ML engine that eases the development of deep neural networks. DJL APIs use Java Native Access (JNA) to call the corresponding Apache MXNet operations. From a hardware perspective, training occurred locally on my laptop using a CPU. However, for the best performance, the DJL team recommends using a machine with at least one GPU. If you don't have a GPU available, there is always the option to use Apache MXNet on Amazon EC2. A nice feature of DJL is that it automatically detects CPU/GPU availability from the hardware configuration, so it can pick the best-performing option.
Load dataset from the source
The footwear data was saved locally and loaded using DJL's ImageFolder dataset, which is a dataset that can retrieve images from a local folder. In DJL terms, a Dataset simply holds the training data. There are dataset implementations that can be used to download data (based on the URL you provide), extract data, and automatically separate data into training and validation sets. The automatic separation is a useful feature, as it is important to never validate the model's performance with the same data the model was trained with. The training dataset is used to find patterns in the data; the validation dataset is used to estimate the footwear model's accuracy during the training process.
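To make the training/validation separation concrete, here is a minimal plain-Java sketch. It does not use DJL's dataset classes; the class name `SplitSketch`, the 80/20 ratio, the seed, and the file names are all assumptions chosen for illustration.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

final class SplitSketch {

    /**
     * Shuffles a copy of the items with a fixed seed, then cuts it into
     * [training, validation] so that no item appears in both parts.
     */
    static <T> List<List<T>> split(List<T> items, double trainFraction, long seed) {
        if (trainFraction <= 0.0 || trainFraction >= 1.0) {
            throw new IllegalArgumentException("trainFraction must be between 0 and 1");
        }
        List<T> shuffled = new ArrayList<>(items);
        Collections.shuffle(shuffled, new Random(seed));
        int cut = (int) Math.round(shuffled.size() * trainFraction);
        List<List<T>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, cut)));
        parts.add(new ArrayList<>(shuffled.subList(cut, shuffled.size())));
        return parts;
    }

    public static void main(String[] args) {
        List<String> files = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            files.add("boots/img" + i + ".jpg"); // hypothetical file names
        }
        List<List<String>> parts = split(files, 0.8, 42L);
        // prints: training: 8, validation: 2
        System.out.println("training: " + parts.get(0).size()
                + ", validation: " + parts.get(1).size());
    }
}
```

Shuffling before cutting keeps both parts representative of the whole set, and the fixed seed makes the split reproducible between runs.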
When structuring the data locally, I didn't go down to the most granular level identified by the UTZappos50k dataset, such as the ankle, knee-high, mid-calf, over the knee, etc. classification labels for boots. My local data are kept at the highest level of classification, which includes only boots, sandals, shoes, and slippers.
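One common way to keep data at this level is one sub-folder per class, with each image's label taken from the folder that contains it. The sketch below shows that idea in plain Java. The `footwear/...` layout, the file names, and the `FolderLabelSketch` class are hypothetical, and this is not DJL's own loading code.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

final class FolderLabelSketch {

    /** Returns the name of the folder that directly contains the image. */
    static String labelOf(Path image) {
        Path parent = image.getParent();
        if (parent == null || parent.getFileName() == null) {
            throw new IllegalArgumentException("image has no parent folder: " + image);
        }
        return parent.getFileName().toString();
    }

    public static void main(String[] args) {
        // Hypothetical layout: one sub-folder per top-level class.
        String[] examples = {
            "footwear/boots/0001.jpg",
            "footwear/sandals/0002.jpg",
            "footwear/shoes/0003.jpg",
            "footwear/slippers/0004.jpg"
        };
        for (String example : examples) {
            System.out.println(example + " -> " + labelOf(Paths.get(example)));
        }
    }
}
```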
Train the model
Now that I have the footwear data separated into training and validation sets, I will use a neural network to train the model.
Training is started by feeding the training data as input to a Block. In DJL terms, a Block is a composable unit that forms a neural network. You can combine Blocks (just like Lego blocks) to form a complex network. At the end of the training process, a Block represents a fully-trained model. The first step is to get a model instance by calling Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH). The getModel() method creates an empty model, constructs the neural network, and sets the neural network to the model.
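The Lego analogy can be sketched in a few lines of plain Java. This is not DJL's Block API: the `Layer` interface and the `dense`, `relu`, and `sequential` helpers are hypothetical names used only to show how small units chain into a network.

```java
import java.util.Arrays;

final class BlockSketch {

    /** A composable unit: takes an input vector, returns an output vector. */
    interface Layer {
        double[] forward(double[] input);
    }

    /** Fully connected layer: out[i] = bias[i] + sum over j of weights[i][j] * input[j]. */
    static Layer dense(double[][] weights, double[] bias) {
        return input -> {
            double[] out = new double[weights.length];
            for (int i = 0; i < weights.length; i++) {
                double sum = bias[i];
                for (int j = 0; j < input.length; j++) {
                    sum += weights[i][j] * input[j];
                }
                out[i] = sum;
            }
            return out;
        };
    }

    /** ReLU activation: negative values become zero. */
    static Layer relu() {
        return input -> Arrays.stream(input).map(v -> Math.max(0.0, v)).toArray();
    }

    /** Chains layers so that each one feeds the next. */
    static Layer sequential(Layer... layers) {
        return input -> {
            double[] current = input;
            for (Layer layer : layers) {
                current = layer.forward(current);
            }
            return current;
        };
    }

    public static void main(String[] args) {
        Layer network = sequential(
                dense(new double[][] {{1, -1}, {-1, 1}}, new double[] {0, 0}),
                relu());
        // prints [2.0, 0.0]
        System.out.println(Arrays.toString(network.forward(new double[] {3, 1})));
    }
}
```

Because `sequential` itself returns a `Layer`, a chained network can be dropped into a larger one, which is the property that makes composition useful.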
The next step is to set up and configure a Trainer by calling the model.newTrainer(config) method. The config object was initialized by calling the setupTrainingConfig(loss) method, which sets the training configuration (or hyperparameters) to determine how the network is trained.
There are multiple hyperparameters set for training:
- newHeight and newWidth: the shape of the image.
- batchSize: the batch size used for training; pick a proper size based on your model.
- numOfOutput: the number of labels; there are 4 labels for footwear classification.
- loss: loss functions evaluate model predictions against true labels, measuring how good (or bad) a model is.
- Initializer: identifies an initialization method; in this case, Xavier initialization.
- MultiFactorTracker: configures the learning rate options.
- Optimizer: an optimization technique to minimize the value of the loss function; in this case, stochastic gradient descent (SGD).
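Three of these settings are easier to picture with numbers. The plain-Java sketch below illustrates the ideas behind them and is not DJL's implementation: it assumes the uniform variant of Xavier initialization, a step schedule where a boundary counts as reached when the update counter is greater than or equal to it, and plain SGD without momentum. All class and method names are hypothetical.

```java
import java.util.Random;

final class HyperparameterSketch {

    /** Xavier (Glorot) uniform bound: sqrt(6 / (fanIn + fanOut)). */
    static double xavierLimit(int fanIn, int fanOut) {
        return Math.sqrt(6.0 / (fanIn + fanOut));
    }

    /** Draws fanIn * fanOut weights uniformly from [-limit, limit). */
    static double[] xavierUniform(int fanIn, int fanOut, long seed) {
        double limit = xavierLimit(fanIn, fanOut);
        Random random = new Random(seed);
        double[] weights = new double[fanIn * fanOut];
        for (int i = 0; i < weights.length; i++) {
            weights[i] = (random.nextDouble() * 2.0 - 1.0) * limit;
        }
        return weights;
    }

    /**
     * Multi-factor schedule: the base rate is multiplied by factor once for
     * every boundary in steps that the update counter has reached.
     * Assumption: a boundary is reached when numUpdate >= step.
     */
    static double learningRate(double base, double factor, int[] steps, int numUpdate) {
        double rate = base;
        for (int step : steps) {
            if (numUpdate >= step) {
                rate *= factor;
            }
        }
        return rate;
    }

    /** One plain SGD step, in place: w = w - rate * gradient. */
    static void sgdStep(double[] weights, double[] gradients, double rate) {
        for (int i = 0; i < weights.length; i++) {
            weights[i] -= rate * gradients[i];
        }
    }

    public static void main(String[] args) {
        int[] steps = {100, 200}; // hypothetical boundaries
        for (int update : new int[] {50, 150, 250}) {
            System.out.println("update " + update + " -> rate "
                    + learningRate(0.1, 0.5, steps, update));
        }
    }
}
```

Lowering the rate in stages lets training take large steps early and smaller, more careful steps as the loss flattens out.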
The next step is to set Metrics, a training listener, and initialize the Trainer with the proper input shape. Metrics collect and report key performance indicators (KPIs) during training that can be used to analyze and monitor training performance and stability. Next, I kick off the training process by calling the fit(trainer, trainingDataset, validateDataset, "build/logs/training") method, which iterates over the training data and stores the patterns found in the model.
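Two indicators commonly tracked during training are accuracy and loss. As a rough illustration of what such numbers mean, and not a reproduction of DJL's Metrics classes, the plain-Java sketch below computes classification accuracy and the cross-entropy loss for a single sample. The class name, method names, and sample values are hypothetical.

```java
final class MetricsSketch {

    /** Fraction of samples whose predicted label matches the true label. */
    static double accuracy(int[] predicted, int[] actual) {
        if (predicted.length != actual.length || predicted.length == 0) {
            throw new IllegalArgumentException("arrays must be non-empty and equal in length");
        }
        int correct = 0;
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] == actual[i]) {
                correct++;
            }
        }
        return (double) correct / predicted.length;
    }

    /** Cross-entropy for one sample: -ln(probability assigned to the true label). */
    static double crossEntropy(double[] probabilities, int trueLabel) {
        // Clamp to avoid ln(0) when the model assigns zero probability.
        double p = Math.max(probabilities[trueLabel], 1e-12);
        return -Math.log(p);
    }

    public static void main(String[] args) {
        // Labels: 0 = boots, 1 = sandals, 2 = shoes, 3 = slippers.
        int[] predicted = {0, 1, 2, 3};
        int[] actual = {0, 1, 2, 0};
        System.out.println("accuracy: " + accuracy(predicted, actual)); // prints accuracy: 0.75
        System.out.println("loss: " + crossEntropy(new double[] {0.7, 0.1, 0.1, 0.1}, 0));
    }
}
```

Accuracy only counts hits and misses, while cross-entropy also penalizes a correct answer given with low confidence, which is why the two are usually watched together.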
At the end of the training, a well-performing validated model artifact is saved locally along with its properties using the model.save(Paths.get(modelParamsPath), modelParamsName) method. The metrics reported during the training process are shown below.
Run inference
Now that I have a model, I can use it to perform inference (or prediction) on new data for which I do not know the classification (or target). After setting the necessary paths to the model and the image to be classified, I obtain an empty model instance using the Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH) method and initialize it using the model.load(Paths.get(modelParamsPath), modelParamsName) method. This loads the model I trained in the previous step. Next, I initialize a Predictor with a specified Translator using the model.newPredictor(translator) method. You'll notice that I'm passing a Translator to the Predictor. In DJL terms, a Translator provides model pre-processing and post-processing functionality. For example, with CV models, images may need to be resized or converted to grayscale; a Translator can do this for you. The Predictor allows me to perform inference on the loaded Model using the predictor.predict(img) method, passing in the image to classify. I'm doing a single prediction, but DJL also supports batch predictions. The inference is stored in predictResult, which contains the probability estimate per label. The model is automatically closed once inference completes, making DJL memory efficient.
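The two halves of that pre-processing and post-processing job can be sketched in plain Java. This does not implement DJL's Translator interface. It assumes packed 0xRRGGBB pixels, the ITU-R BT.601 luma weights for grayscale conversion, and a softmax over raw model scores; the class name and the sample scores are hypothetical.

```java
final class TranslatorSketch {

    static final String[] LABELS = {"boots", "sandals", "shoes", "slippers"};

    /** Pre-processing: packed 0xRRGGBB pixels to grayscale values in [0, 1]. */
    static float[] toGrayscale(int[] rgbPixels) {
        float[] gray = new float[rgbPixels.length];
        for (int i = 0; i < rgbPixels.length; i++) {
            int r = (rgbPixels[i] >> 16) & 0xFF;
            int g = (rgbPixels[i] >> 8) & 0xFF;
            int b = rgbPixels[i] & 0xFF;
            // ITU-R BT.601 luma weights, scaled from 0..255 down to 0..1.
            gray[i] = (float) ((0.299 * r + 0.587 * g + 0.114 * b) / 255.0);
        }
        return gray;
    }

    /** Post-processing: turns raw scores into probabilities that sum to 1. */
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double sum = 0.0;
        double[] result = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            // Subtracting the maximum keeps exp() from overflowing.
            result[i] = Math.exp(scores[i] - max);
            sum += result[i];
        }
        for (int i = 0; i < result.length; i++) {
            result[i] /= sum;
        }
        return result;
    }

    /** Index of the largest value. */
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probabilities = softmax(new double[] {2.0, 0.5, 0.1, -1.0}); // hypothetical scores
        for (int i = 0; i < LABELS.length; i++) {
            System.out.println(LABELS[i] + ": " + probabilities[i]);
        }
        System.out.println("prediction: " + LABELS[argMax(probabilities)]); // prints prediction: boots
    }
}
```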
The inferences (per image) are shown below with their corresponding probability scores.
Takeaways & Next Steps
I've been developing Java-based applications since the late '90s and started my machine learning journey in 2017. My journey would've been much easier had DJL been around back then. I highly recommend that Java developers looking to transition to machine learning give DJL a try. In my example, I developed the footwear classification model from scratch; however, DJL also allows developers to deploy pre-trained models with minimal effort. DJL also comes with popular datasets out of the box to allow developers to instantly get started with ML. Before starting with DJL, I would recommend that you have a firm understanding of the ML lifecycle and are familiar with common ML terms. Once you have a basic understanding of ML, you can quickly come up to speed on DJL APIs.
Amazon has open-sourced DJL; further detailed information about the toolkit can be found on the DJL website and the Java Library API Specification page. The code for the footwear classification model can be found on GitLab. Good luck on your ML journey, and please feel free to reach out to me if you have any questions.



