import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import org.opencv.core.*;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import org.opencv.videoio.VideoCapture;
/**
 * Runs a YOLOv8 ONNX model through ONNX Runtime on a still image or a webcam stream and draws
 * the detections with OpenCV.
 *
 * <p>Each instance owns a native ONNX Runtime session; call {@link #close()} (or use
 * try-with-resources) when finished with it.
 */
public class YoloDetector implements AutoCloseable {
    private static final String YOLO_MODEL_PATH = "./model/yolov8n.onnx";
    /** Fallback square input side, used only when the model's input H/W are dynamic. */
    private static final int INPUT_SIZE = 416;
    /** Minimum best-class score for a candidate box to be kept. */
    private static final float CONF_THRESHOLD = 0.5f;
    /** Same-class boxes overlapping the kept box by more than this IoU are suppressed. */
    private static final float IOU_THRESHOLD = 0.45f;
    /** Number of leading rows (cx, cy, w, h) in the YOLOv8 output before the class scores. */
    private static final int BOX_ROWS = 4;
    private static final List<String> CLASS_NAMES = List.of(
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
        "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
        "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
        "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball",
        "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
        "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
        "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
        "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
        "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
        "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
        "toothbrush"
    );

    static {
        try {
            System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
            System.out.println("[成功] OpenCV 自动加载完成,版本:" + Core.VERSION);
        } catch (UnsatisfiedLinkError | SecurityException e) {
            // loadLibrary reports a missing native library with an Error, not an Exception,
            // so catching Exception alone would never fire.
            throw new RuntimeException("[失败] OpenCV 加载失败,请检查 Maven 依赖", e);
        }
    }

    private OrtEnvironment ortEnv;
    private OrtSession.SessionOptions sessionOptions;
    private OrtSession ortSession;
    /** Name of the model's first input tensor. */
    private String inputName;
    /** Square side length the model expects; frames are resized to this before inference. */
    private int inputSize;

    /**
     * Loads the ONNX model from {@code YOLO_MODEL_PATH} and reads its input name and size.
     *
     * @throws RuntimeException if the session cannot be created or inspected
     */
    public YoloDetector() {
        try {
            ortEnv = OrtEnvironment.getEnvironment();
            sessionOptions = new OrtSession.SessionOptions();
            ortSession = ortEnv.createSession(YOLO_MODEL_PATH, sessionOptions);
            inputName = ortSession.getInputNames().iterator().next();
            inputSize = resolveInputSize(ortSession.getInputInfo().get(inputName));
            System.out.println("[成功] YOLO 模型加载完成");
        } catch (OrtException | RuntimeException e) {
            // Do not leak the partially built native session/options on failure.
            releaseNativeResources();
            throw new RuntimeException("[失败] YOLO 模型加载失败,请检查模型路径", e);
        }
    }

    /**
     * Returns the square spatial size declared by the model input, or {@link #INPUT_SIZE} when
     * the declared H/W are dynamic (negative) or the input is not a 4-D square tensor.
     */
    private static int resolveInputSize(NodeInfo info) {
        if (info != null && info.getInfo() instanceof TensorInfo) {
            long[] shape = ((TensorInfo) info.getInfo()).getShape();
            // Expected NCHW; dynamic dimensions are reported as negative values.
            if (shape.length == 4 && shape[2] > 0 && shape[2] == shape[3]) {
                return (int) shape[2];
            }
        }
        return INPUT_SIZE;
    }

    /**
     * Converts a BGR image into the model input: resized to {@code inputSize} square, RGB,
     * scaled to [0, 1], laid out as planar CHW floats.
     *
     * @param originalImg 3-channel BGR image (as produced by {@code imread} / {@code VideoCapture})
     * @return {@code 3 * inputSize * inputSize} floats in CHW order
     */
    private float[] preprocess(Mat originalImg) {
        Mat resized = new Mat();
        Mat rgb = new Mat();
        Mat floatImg = new Mat();
        try {
            Imgproc.resize(originalImg, resized, new Size(inputSize, inputSize));
            Imgproc.cvtColor(resized, rgb, Imgproc.COLOR_BGR2RGB);
            rgb.convertTo(floatImg, CvType.CV_32FC3, 1.0 / 255.0);

            int area = inputSize * inputSize;
            float[] hwc = new float[area * 3];
            floatImg.get(0, 0, hwc);
            // OpenCV stores pixels interleaved (HWC); the model wants one plane per channel (CHW).
            float[] chw = new float[area * 3];
            for (int p = 0; p < area; p++) {
                chw[p] = hwc[p * 3];
                chw[area + p] = hwc[p * 3 + 1];
                chw[2 * area + p] = hwc[p * 3 + 2];
            }
            return chw;
        } finally {
            resized.release();
            rgb.release();
            floatImg.release();
        }
    }

    /**
     * Runs the model on one preprocessed frame.
     *
     * @param chwData output of {@link #preprocess(Mat)}
     * @return the first batch element of the output, indexed {@code [row][candidate]} where rows
     *     are (cx, cy, w, h, class scores...) — assumes the standard YOLOv8 export layout
     * @throws RuntimeException if inference fails or the output is not a 3-D float tensor
     */
    private float[][] infer(float[] chwData) {
        if (ortSession == null) {
            throw new IllegalStateException("YoloDetector has been closed");
        }
        long[] shape = {1, 3, inputSize, inputSize};
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(ortEnv, FloatBuffer.wrap(chwData), shape);
             OrtSession.Result result = ortSession.run(Collections.singletonMap(inputName, inputTensor))) {
            // getValue() copies into Java arrays, so the data stays valid after the Result closes.
            Object raw = result.get(0).getValue();
            if (!(raw instanceof float[][][])) {
                throw new IllegalStateException("推理失败: 模型输出不是 3 维 float 张量");
            }
            return ((float[][][]) raw)[0];
        } catch (OrtException e) {
            throw new RuntimeException("推理失败", e);
        }
    }

    /**
     * Decodes YOLOv8 output into boxes in original-image pixel coordinates and applies NMS.
     *
     * @param outputs {@code [4 + numClasses][numCandidates]} rows from {@link #infer(float[])}
     * @param originalImg the image the boxes are mapped back onto
     * @throws IllegalStateException if the output row count does not match {@code CLASS_NAMES}
     */
    private List<DetectionResult> postprocess(float[][] outputs, Mat originalImg) {
        int numClasses = CLASS_NAMES.size();
        if (outputs.length != BOX_ROWS + numClasses) {
            throw new IllegalStateException(
                "模型输出行数 " + outputs.length + " 与类别数 " + numClasses + " 不匹配");
        }
        int imgWidth = originalImg.cols();
        int imgHeight = originalImg.rows();
        // Boxes are expressed in model-input pixels; rescale each axis to the original image.
        float scaleX = (float) imgWidth / inputSize;
        float scaleY = (float) imgHeight / inputSize;
        int numCandidates = outputs.length == 0 ? 0 : outputs[0].length;

        List<DetectionResult> candidates = new ArrayList<>();
        for (int i = 0; i < numCandidates; i++) {
            // YOLOv8 has no objectness score: confidence is the best class score.
            int clsId = -1;
            float bestScore = 0f;
            for (int c = 0; c < numClasses; c++) {
                float score = outputs[BOX_ROWS + c][i];
                if (score > bestScore) {
                    bestScore = score;
                    clsId = c;
                }
            }
            if (clsId < 0 || bestScore < CONF_THRESHOLD) {
                continue;
            }
            float cx = outputs[0][i];
            float cy = outputs[1][i];
            float w = outputs[2][i];
            float h = outputs[3][i];
            int x1 = clamp(Math.round((cx - w / 2) * scaleX), 0, imgWidth - 1);
            int y1 = clamp(Math.round((cy - h / 2) * scaleY), 0, imgHeight - 1);
            int x2 = clamp(Math.round((cx + w / 2) * scaleX), 0, imgWidth - 1);
            int y2 = clamp(Math.round((cy + h / 2) * scaleY), 0, imgHeight - 1);
            candidates.add(new DetectionResult(x1, y1, x2, y2, bestScore, clsId));
        }
        return nms(candidates);
    }

    /** Restricts {@code value} to the inclusive range [{@code lo}, {@code hi}]. */
    private static int clamp(int value, int lo, int hi) {
        return Math.max(lo, Math.min(hi, value));
    }

    /**
     * Greedy per-class non-maximum suppression: keeps the highest-confidence box and drops any
     * lower-scored box of the same class whose IoU with it exceeds {@code IOU_THRESHOLD}.
     * Sorts {@code results} in place by descending confidence.
     */
    private List<DetectionResult> nms(List<DetectionResult> results) {
        List<DetectionResult> kept = new ArrayList<>();
        if (results.isEmpty()) {
            return kept;
        }
        results.sort((a, b) -> Float.compare(b.confidence, a.confidence));
        boolean[] suppressed = new boolean[results.size()];
        for (int i = 0; i < results.size(); i++) {
            if (suppressed[i]) {
                continue;
            }
            DetectionResult current = results.get(i);
            kept.add(current);
            for (int j = i + 1; j < results.size(); j++) {
                if (suppressed[j]) {
                    continue;
                }
                DetectionResult other = results.get(j);
                // Objects of different classes may legitimately overlap, so only same-class boxes compete.
                if (other.clsId == current.clsId && calculateIoU(current, other) > IOU_THRESHOLD) {
                    suppressed[j] = true;
                }
            }
        }
        return kept;
    }

    /** Intersection-over-union of two boxes; 0 when they do not overlap or are degenerate. */
    private float calculateIoU(DetectionResult a, DetectionResult b) {
        int interW = Math.min(a.x2, b.x2) - Math.max(a.x1, b.x1);
        int interH = Math.min(a.y2, b.y2) - Math.max(a.y1, b.y1);
        if (interW <= 0 || interH <= 0) {
            return 0f;
        }
        float interArea = (float) interW * interH;
        float areaA = (float) (a.x2 - a.x1) * (a.y2 - a.y1);
        float areaB = (float) (b.x2 - b.x1) * (b.y2 - b.y1);
        float union = areaA + areaB - interArea;
        return union > 0 ? interArea / union : 0f;
    }

    /** Draws each box and a "class confidence" label onto {@code img} in place. */
    private void drawResults(Mat img, List<DetectionResult> results) {
        Scalar color = new Scalar(0, 255, 0);
        for (DetectionResult result : results) {
            Imgproc.rectangle(img, new Point(result.x1, result.y1),
                new Point(result.x2, result.y2), color, 2);
            // Locale.ROOT keeps a '.' decimal separator regardless of the JVM default locale.
            String label = CLASS_NAMES.get(result.clsId) + " "
                + String.format(Locale.ROOT, "%.2f", result.confidence);
            // Keep the label on-screen when the box touches the top edge of the image.
            int textY = Math.max(result.y1 - 10, 12);
            Imgproc.putText(img, label, new Point(result.x1, textY),
                Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, color, 1);
        }
    }

    /** Full pipeline for one BGR frame: preprocess, infer, decode, NMS. */
    private List<DetectionResult> detect(Mat bgrImage) {
        return postprocess(infer(preprocess(bgrImage)), bgrImage);
    }

    /**
     * Detects objects in the image at {@code imgPath} and writes an annotated copy to
     * {@code savePath}. Read and write failures are reported on stderr.
     */
    public void detectImage(String imgPath, String savePath) {
        Mat originalImg = Imgcodecs.imread(imgPath);
        try {
            if (originalImg.empty()) {
                System.err.println("无法读取图片:" + imgPath);
                return;
            }
            List<DetectionResult> results = detect(originalImg);
            drawResults(originalImg, results);
            if (!Imgcodecs.imwrite(savePath, originalImg)) {
                System.err.println("无法保存图片:" + savePath);
                return;
            }
            System.out.println("图片检测完成!结果保存至:" + savePath);
            System.out.println("检测到目标数量:" + results.size());
        } finally {
            originalImg.release();
        }
    }

    /**
     * Runs detection on camera 0 and shows annotated frames in a window until ESC is pressed or
     * the stream ends. The camera and window are released even if detection throws.
     */
    public void detectCamera() {
        VideoCapture cap = new VideoCapture(0);
        if (!cap.isOpened()) {
            System.err.println("无法打开摄像头");
            cap.release();
            return;
        }
        Mat frame = new Mat();
        try {
            while (cap.read(frame) && !frame.empty()) {
                drawResults(frame, detect(frame));
                HighGui.imshow("YOLOv8 Real-Time Detection", frame);
                if (HighGui.waitKey(1) == 27) { // 27 = ESC
                    break;
                }
            }
        } finally {
            frame.release();
            cap.release();
            HighGui.destroyAllWindows();
        }
    }

    /** Releases the native ONNX Runtime session and its options; safe to call more than once. */
    @Override
    public void close() {
        releaseNativeResources();
    }

    /** Best-effort release of native resources; a close failure is reported, not rethrown. */
    private void releaseNativeResources() {
        try {
            if (ortSession != null) {
                ortSession.close();
            }
        } catch (OrtException e) {
            System.err.println("关闭 ONNX 会话失败:" + e.getMessage());
        } finally {
            ortSession = null;
            if (sessionOptions != null) {
                sessionOptions.close();
                sessionOptions = null;
            }
        }
    }

    /** One detection: pixel box corners in the original image, confidence, and class index. */
    static class DetectionResult {
        final int x1, y1, x2, y2;
        final float confidence;
        final int clsId;

        public DetectionResult(int x1, int y1, int x2, int y2, float confidence, int clsId) {
            this.x1 = x1;
            this.y1 = y1;
            this.x2 = x2;
            this.y2 = y2;
            this.confidence = confidence;
            this.clsId = clsId;
        }
    }

    public static void main(String[] args) {
        try (YoloDetector detector = new YoloDetector()) {
            detector.detectImage("test.jpg", "result.jpg");
        }
    }
}