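DJL (Deep Java Library) is an open-source deep learning toolkit announced by Amazon in 2019. It is built on top of existing deep learning frameworks using native Java concepts and supports MXNet, TensorFlow and PyTorch. PyTorch engine documentation: http://docs.djl.ai/engines/pytorch/pytorch-engine/index.html
DJL common dependency (PyTorch engine)
Set the environment variable PYTORCH_VERSION to override the default PyTorch native library version.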
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.22.1</version>
<scope>runtime</scope>
</dependency>
DJL Supported PyTorch versions
Since DJL 0.14.0, pytorch-engine can load older versions of the PyTorch native library:
PyTorch engine version    PyTorch native library versions
pytorch-engine:0.22.1     1.11.0, 1.12.1, 1.13.1, 2.0.0
pytorch-engine:0.21.0     1.11.0, 1.12.1, 1.13.1
pytorch-engine:0.20.0     1.11.0, 1.12.1, 1.13.0
pytorch-engine:0.19.0     1.10.0, 1.11.0, 1.12.1
pytorch-engine:0.18.0     1.9.1, 1.10.0, 1.11.0
pytorch-engine:0.17.0     1.9.1, 1.10.0, 1.11.0
pytorch-engine:0.16.0     1.8.1, 1.9.1, 1.10.0
A newer pytorch-engine still works with older PyTorch versions (it is backward compatible), so use the latest pytorch-engine release where possible.
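To make the compatibility table easier to apply, here is a minimal sketch of a hypothetical helper (`PyTorchCompat` is not part of DJL) that hard-codes the rows above and checks an optional `PYTORCH_VERSION` override against the pytorch-engine version you depend on. It only knows the versions listed in the table; anything else is reported as unlisted rather than guessed.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper (not part of DJL): the pytorch-engine compatibility table as data. */
final class PyTorchCompat {
    private static final Map<String, List<String>> SUPPORTED = new LinkedHashMap<>();

    static {
        // pytorch-engine version -> PyTorch native library versions it can load
        SUPPORTED.put("0.22.1", Arrays.asList("1.11.0", "1.12.1", "1.13.1", "2.0.0"));
        SUPPORTED.put("0.21.0", Arrays.asList("1.11.0", "1.12.1", "1.13.1"));
        SUPPORTED.put("0.20.0", Arrays.asList("1.11.0", "1.12.1", "1.13.0"));
        SUPPORTED.put("0.19.0", Arrays.asList("1.10.0", "1.11.0", "1.12.1"));
        SUPPORTED.put("0.18.0", Arrays.asList("1.9.1", "1.10.0", "1.11.0"));
        SUPPORTED.put("0.17.0", Arrays.asList("1.9.1", "1.10.0", "1.11.0"));
        SUPPORTED.put("0.16.0", Arrays.asList("1.8.1", "1.9.1", "1.10.0"));
    }

    private PyTorchCompat() {}

    /** Native versions listed for the engine version, or an empty list if it is not in the table. */
    static List<String> supportedNativeVersions(String engineVersion) {
        return SUPPORTED.getOrDefault(engineVersion, Collections.emptyList());
    }

    /** True only if the table lists nativeVersion for engineVersion. */
    static boolean isSupported(String engineVersion, String nativeVersion) {
        return supportedNativeVersions(engineVersion).contains(nativeVersion);
    }

    public static void main(String[] args) {
        String engineVersion = "0.22.1"; // keep in sync with the pytorch-engine version in the pom
        String override = System.getenv("PYTORCH_VERSION");
        if (override == null) {
            System.out.println("PYTORCH_VERSION is not set; pytorch-engine "
                    + engineVersion + " uses its default native library");
        } else if (isSupported(engineVersion, override)) {
            System.out.println("PyTorch " + override + " is listed for pytorch-engine " + engineVersion);
        } else {
            System.out.println("PyTorch " + override + " is not listed for pytorch-engine "
                    + engineVersion + "; listed: " + supportedNativeVersions(engineVersion));
        }
    }
}
```

Run it in the same environment as the application; if the override is reported as unlisted, either change PYTORCH_VERSION or move to a pytorch-engine release that lists it.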
Windows CPU
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>win-x86_64</classifier>
<scope>runtime</scope>
<version>2.0.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.0-0.22.1</version>
<scope>runtime</scope>
</dependency>
Windows GPU
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cu118</artifactId>
<classifier>win-x86_64</classifier>
<version>2.0.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.0-0.22.1</version>
<scope>runtime</scope>
</dependency>
Linux CPU
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>linux-x86_64</classifier>
<scope>runtime</scope>
<version>2.0.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.0-0.22.1</version>
<scope>runtime</scope>
</dependency>
Linux GPU
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cu118</artifactId>
<classifier>linux-x86_64</classifier>
<version>2.0.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.0-0.22.1</version>
<scope>runtime</scope>
</dependency>
macOS M1
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>osx-aarch64</classifier>
<version>2.0.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.0-0.22.1</version>
<scope>runtime</scope>
</dependency>
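Between the platform blocks only the native artifact (`pytorch-native-cpu` or `pytorch-native-cu118`) and the `classifier` change. The sketch below is a hypothetical helper, assuming the JVM reports `os.arch` as `amd64`/`x86_64` or `aarch64`/`arm64`; it maps the current platform to one of the three classifiers listed here and returns null for anything else (for example Linux on ARM), since this section does not cover those platforms.

```java
import java.util.Locale;

/** Hypothetical helper: picks the pytorch-native artifact and classifier listed in this section. */
final class NativeClassifier {
    private NativeClassifier() {}

    /** Returns win-x86_64, linux-x86_64 or osx-aarch64, or null for platforms not listed here. */
    static String classifier(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String arch = osArch.toLowerCase(Locale.ROOT);
        boolean x64 = arch.equals("amd64") || arch.equals("x86_64");
        boolean arm64 = arch.equals("aarch64") || arch.equals("arm64");
        if (os.startsWith("windows") && x64) {
            return "win-x86_64";
        }
        if (os.startsWith("linux") && x64) {
            return "linux-x86_64";
        }
        if ((os.startsWith("mac") || os.contains("darwin")) && arm64) {
            return "osx-aarch64";
        }
        return null;
    }

    /** The CUDA 11.8 build exists only in the Windows/Linux GPU blocks; macOS M1 falls back to CPU. */
    static String artifactId(String classifier, boolean gpu) {
        return gpu && !"osx-aarch64".equals(classifier) ? "pytorch-native-cu118" : "pytorch-native-cpu";
    }

    public static void main(String[] args) {
        String c = classifier(System.getProperty("os.name"), System.getProperty("os.arch"));
        if (c == null) {
            System.out.println("This platform is not covered by the dependency blocks in this section");
        } else {
            System.out.println("artifactId=" + artifactId(c, false) + ", classifier=" + c);
        }
    }
}
```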
Pre-trained models
The PyTorch model zoo contains Computer Vision (CV) models, grouped by task:
CV: Image Classification, Object Detection, Style Transfer, Image Generation
DJL Model Zoo dependency:
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-model-zoo</artifactId>
<version>0.22.1</version>
</dependency>
pom.xml (complete example, Windows CPU)
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.example</groupId>
<artifactId>w11</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<name>w11</name>
<url>http://maven.apache.org</url>
<properties>
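<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>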
</properties>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.22.1</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.22.1</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>win-x86_64</classifier>
<scope>runtime</scope>
<version>2.0.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.0-0.22.1</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-nop</artifactId>
<version>1.7.2</version>
<type>jar</type>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>3.8.1</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Java code (src/main/java/org/example/App.java)
package org.example;
import java.nio.file.Paths;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;
public class App {
    public static void main(String[] args) throws Exception {
        // Download the traced ResNet-18 model and its ImageNet class labels (synset.txt)
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
        /*
         * Corresponding torchvision preprocessing:
         * preprocess = transforms.Compose([
         *     transforms.Resize(256),
         *     transforms.CenterCrop(224),
         *     transforms.ToTensor(),
         *     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         * ])
         * Note: torchvision's Resize(256) scales the shorter side to 256, while DJL's
         * Resize(256) scales the image to 256x256, so non-square images differ slightly.
         */
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .addTransform(new Resize(256))
                .addTransform(new CenterCrop(224, 224))
                .addTransform(new ToTensor())
                .addTransform(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}))
                .optApplySoftmax(true) // return probabilities instead of raw logits
                .build();
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get("build/pytorch_models/resnet18"))
                .optOption("mapLocation", "true") // this model requires mapLocation for GPU
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .build();
        // Image img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");
        // 0.png is an image of a handwritten digit 0; copy it into the model directory first.
        // ResNet-18 is an ImageNet classifier, so the predicted labels are ImageNet classes.
        Image img = ImageFactory.getInstance().fromFile(Paths.get("build/pytorch_models/resnet18/0.png"));
        // ZooModel and Predictor hold native memory, so close them with try-with-resources
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
             Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Classifications classifications = predictor.predict(img);
            System.out.println(classifications);
        }
    }
}
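To see what the translator in App actually computes, the sketch below redoes the arithmetic in plain Java. It is an illustration, not DJL code, and `TranslatorMath` is a hypothetical name: it assumes ToTensor maps 0..255 to 0..1, that CenterCrop uses integer division for the offset, and that optApplySoftmax(true) applies a standard softmax to the logits.

```java
import java.util.Arrays;

/** Hypothetical illustration of the numbers behind the image classification translator. */
final class TranslatorMath {
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    private TranslatorMath() {}

    /** ToTensor (0..255 -> 0..1) followed by Normalize, (x - mean) / std, for one RGB pixel. */
    static float[] normalizePixel(int r, int g, int b) {
        int[] rgb = {r, g, b};
        float[] out = new float[3];
        for (int c = 0; c < 3; c++) {
            out[c] = (rgb[c] / 255f - MEAN[c]) / STD[c];
        }
        return out;
    }

    /** Top-left offset of a centered crop of length crop inside a side of length size. */
    static int centerCropOffset(int size, int crop) {
        return (size - crop) / 2;
    }

    /** Numerically stable softmax: subtract the largest logit before exponentiating. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0;
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    public static void main(String[] args) {
        // After Resize(256) the image is 256x256, so CenterCrop(224, 224) starts at (16, 16)
        System.out.println("crop offset: " + centerCropOffset(256, 224));
        System.out.println("white pixel: " + Arrays.toString(normalizePixel(255, 255, 255)));
        System.out.println("softmax: " + Arrays.toString(softmax(new double[] {2.0, 1.0, 0.1})));
    }
}
```

For a pure white pixel the normalized values come out around 2.25, 2.43 and 2.64, and after Resize(256) a 224x224 center crop starts 16 pixels in from each edge.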
Environment
PyTorch: 1.10.2, DJL: 0.22.1, OS: developed on Windows first, then deployed to and run on CentOS 7
Creating the project on Linux
mvn archetype:generate -DgroupId=org.test -DartifactId=lnx1 -DarchetypeArtifactId=maven-archetype-quickstart -DinteractiveMode=false
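Generating the PyTorch model in Python
import torch
import torchvision
# pretrained=True loads the ImageNet weights (torchvision 0.11.x pairs with torch 1.10.2);
# resnet50() without it returns a network with randomly initialized weights
model = torchvision.models.resnet50(pretrained=True)
model.eval()  # switch BatchNorm/Dropout to inference mode before tracing
example = torch.rand(1, 3, 224, 224)  # dummy input: batch 1, 3 channels, 224x224
traced_script_module = torch.jit.trace(model, example)  # record the forward pass as TorchScript
traced_script_module.save("traced_model_resnet50.pt")
Note: do not save the model with torch.save; trace it with torch.jit.trace and save the traced module, because DJL's PyTorch engine loads TorchScript models rather than pickled Python modules.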