OpenCV  5.0.0-pre
Open Source Computer Vision
Loading...
Searching...
No Matches
samples/dnn/text_detection.cpp
/*
Text detection model (EAST): https://github.com/argman/EAST
Download link for EAST model: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
DB detector model:
https://drive.google.com/uc?export=download&id=17_ABp79PlFt9yPCxSaarVc_DKTmrSGGf
CRNN Text recognition model sourced from: https://github.com/meijieru/crnn.pytorch
How to convert from .pth (PyTorch checkpoint) to .onnx:
Using classes from: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
Additional converted ONNX text recognition models available for direct download:
Download link: https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing
These models are taken from: https://github.com/clovaai/deep-text-recognition-benchmark
Importing and using the CRNN model in PyTorch:
import torch
from models.crnn import CRNN
model = CRNN(32, 1, 37, 256)
model.load_state_dict(torch.load('crnn.pth'))
dummy_input = torch.randn(1, 1, 32, 100)
torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
Usage: ./example_dnn_text_detection DB
*/
#include <iostream>
#include <fstream>
#include <opencv2/dnn.hpp>
#include "common.hpp"
using namespace cv;
using namespace std;
using namespace cv::dnn;
// Help banner for the sample: main() prints it when the model alias is missing
// or help is requested, and also registers it with the command-line parser.
const string about =
    "Use this script for Text Detection and Recognition using OpenCV. \n\n"
    "Firstly, download required models using `download_models.py` (if not already done). "
    "Set environment variable OPENCV_DOWNLOAD_CACHE_DIR to point to the directory where models are downloaded. "
    "Also, point OPENCV_SAMPLES_DATA_PATH to opencv/samples/data.\n"
    "To run:\n\t Example: ./example_dnn_text_detection modelName(i.e. DB or East) "
    "--ocr_model=<path to VGG_CTC.onnx>\n\n"
    "Detection model path can also be specified using --model argument. \n\n"
    "Download ocr model using: python download_models.py OCR \n\n";
// Command-line keys to parse the input arguments.
// Each entry has the form "{ name [alias] | default value | help text }".
// Deliberately non-const: main() appends the keys returned by
// genPreprocArguments() for the selected model before parsing a second time.
string keys =
"{ help h | | Print help message. }"
"{ input i | right.jpg | Path to an input image. }"
"{ @alias | | An alias name of model to extract preprocessing parameters from models.yml file. }"
"{ zoo | ../dnn/models.yml | An optional path to file with preprocessing parameters }"
"{ ocr_model | | Path to a binary .onnx model for recognition. }"
"{ model | | Path to detection model file. }"
"{ thr | 0.5 | Confidence threshold for EAST detector. }"
"{ nms | 0.4 | Non-maximum suppression threshold for EAST detector. }"
"{ binaryThreshold bt | 0.3 | Confidence threshold for the binary map in DB detector. }"
"{ polygonThreshold pt | 0.5 | Confidence threshold for polygons in DB detector. }"
"{ maxCandidate max | 200 | Max candidates for polygons in DB detector. }"
"{ unclipRatio ratio | 2.0 | Unclip ratio for DB detector. }"
"{ vocabularyPath vp | alphabet_36.txt | Path to vocabulary file. }";
// Function prototype for the four-point perspective transform.
// Warps the region of `frame` bounded by the four points in `vertices` onto a
// fixed 100x32 image written to `result`. The vertices are mapped, in order, to
// the output points (0, h-1), (0, 0), (w-1, 0) and (w-1, h-1).
static void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result);
// Function prototype for the recognition/visualisation stage.
//   frame       - input image the detections refer to.
//   detResults  - detected regions; each one must contain exactly four points.
//   ocr_model   - path to the recognition model; when empty, recognition is
//                 skipped and only a warning is printed.
//   imreadRGB   - when false the frame is converted to grayscale before cropping;
//                 when true the frame is used as is.
//   board       - image on which recognised text and region outlines are drawn.
//   fontFace, fontSize, fontWeight - font settings forwarded to putText().
//   vocabulary  - vocabulary handed to the recognition model.
// Shows the frame and the board side by side in a window and blocks until a key
// is pressed.
static void processFrame(
const Mat& frame,
const vector<vector<Point>>& detResults,
const std::string& ocr_model,
bool imreadRGB,
Mat& board,
FontFace& fontFace,
int fontSize,
int fontWeight,
const vector<std::string>& vocabulary
);
int main(int argc, char** argv) {
// Setting up command-line parser with the specified keys
CommandLineParser parser(argc, argv, keys);
if (!parser.has("@alias") || parser.has("help"))
{
cout << about << endl;
parser.printMessage();
return -1;
}
const string modelName = parser.get<String>("@alias");
const string zooFile = findFile(parser.get<String>("zoo"));
keys += genPreprocArguments(modelName, zooFile, "");
keys += genPreprocArguments(modelName, zooFile, "ocr_");
parser = CommandLineParser(argc, argv, keys);
parser.about(about);
// Parsing command-line arguments
String sha1 = parser.get<String>("download_sha");
if (sha1.empty()){
sha1 = parser.get<String>("sha1");
}
String ocr_sha1 = parser.get<String>("ocr_sha1");
String detModelPath = findModel(parser.get<String>("model"), sha1);
String ocr = findModel(parser.get<String>("ocr_model"), ocr_sha1);
int height = parser.get<int>("height");
int width = parser.get<int>("width");
bool imreadRGB = parser.get<bool>("rgb");
String vocPath = parser.get<String>("vocabularyPath");
float binThresh = parser.get<float>("binaryThreshold");
float polyThresh = parser.get<float>("polygonThreshold");
double unclipRatio = parser.get<double>("unclipRatio");
uint maxCandidates = parser.get<uint>("maxCandidate");
float confThreshold = parser.get<float>("thr");
float nmsThreshold = parser.get<float>("nms");
Scalar mean = parser.get<Scalar>("mean");
// Ensuring the provided arguments are valid
if (!parser.check()) {
parser.printErrors();
return 1;
}
// Asserting detection model path is provided
CV_Assert(!detModelPath.empty());
vector<vector<Point>> detResults;
// Reading the input image
Mat frame = imread(samples::findFile(parser.get<String>("input")));
Mat board(frame.size(), frame.type(), Scalar(255, 255, 255));
int stdSize = 20;
int stdWeight = 400;
int stdImgSize = 512;
int imgWidth = min(frame.rows, frame.cols);
int size = (stdSize*imgWidth)/stdImgSize;
int weight = (stdWeight*imgWidth)/stdImgSize;
FontFace fontFace("sans");
// Initializing and configuring the text detection model based on the provided config
if (modelName == "East") {
// EAST Detector initialization
TextDetectionModel_EAST detector(detModelPath);
detector.setConfidenceThreshold(confThreshold)
.setNMSThreshold(nmsThreshold);
// Setting input parameters specific to EAST model
detector.setInputParams(1.0, Size(width, height), mean, true);
// Performing text detection
detector.detect(frame, detResults);
}
else if (modelName == "DB") {
// DB Detector initialization
TextDetectionModel_DB detector(detModelPath);
detector.setBinaryThreshold(binThresh)
.setPolygonThreshold(polyThresh)
.setUnclipRatio(unclipRatio)
.setMaxCandidates(maxCandidates);
// Setting input parameters specific to DB model
detector.setInputParams(1.0 / 255.0, Size(width, height), mean);
// Performing text detection
detector.detect(frame, detResults);
}
else {
cout << "[ERROR]: Unsupported file config for the detector model. Valid values: east/db" << endl;
return 1;
}
// Reading and storing vocabulary for text recognition
CV_Assert(!vocPath.empty());
ifstream vocFile;
vocFile.open(samples::findFile(vocPath));
CV_Assert(vocFile.is_open());
std::string vocLine;
vector<std::string> vocabulary;
while (getline(vocFile, vocLine)) {
vocabulary.push_back(vocLine);
}
processFrame(frame, detResults, ocr, imreadRGB, board, fontFace, size, weight, vocabulary);
return 0;
}
// Performs a perspective transform for a four-point region
static void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result) {
const Size outputSize = Size(100, 32);
// Defining target vertices for the perspective transform
Point2f targetVertices[4] = {
Point(0, outputSize.height - 1),
Point(0, 0),
Point(outputSize.width - 1, 0),
Point(outputSize.width - 1, outputSize.height - 1)
};
// Computing the perspective transform matrix
Mat rotationMatrix = getPerspectiveTransform(vertices, targetVertices);
// Applying the perspective transform to the region
warpPerspective(frame, result, rotationMatrix, outputSize);
}
void processFrame(
const Mat& frame,
const vector<vector<Point>>& detResults,
const std::string& ocr_model,
bool imreadRGB,
Mat& board,
FontFace& fontFace,
int fontSize,
int fontWeight,
const vector<std::string>& vocabulary
) {
if (detResults.size() > 0) {
// Text Recognition
Mat recInput;
if (!imreadRGB) {
cvtColor(frame, recInput, cv::COLOR_BGR2GRAY);
} else {
recInput = frame;
}
vector<vector<Point>> contours;
for (uint i = 0; i < detResults.size(); i++) {
const auto& quadrangle = detResults[i];
CV_CheckEQ(quadrangle.size(), (size_t)4, "");
contours.emplace_back(quadrangle);
vector<Point2f> quadrangle_2f;
for (int j = 0; j < 4; j++)
quadrangle_2f.emplace_back(detResults[i][j]);
// Cropping the detected text region using a four-point transform
Mat cropped;
fourPointsTransform(recInput, &quadrangle_2f[0], cropped);
if(!ocr_model.empty()){
TextRecognitionModel recognizer(ocr_model);
recognizer.setVocabulary(vocabulary);
recognizer.setDecodeType("CTC-greedy");
// Setting input parameters for the recognition model
double recScale = 1.0 / 127.5;
Scalar recMean = Scalar(127.5);
Size recInputSize = Size(100, 32);
recognizer.setInputParams(recScale, recInputSize, recMean);
// Recognizing text from the cropped image
string recognitionResult = recognizer.recognize(cropped);
cout << i << ": '" << recognitionResult << "'" << endl;
// Displaying the recognized text on the image
putText(board, recognitionResult, Point(detResults[i][1].x, detResults[i][0].y), Scalar(0, 0, 0), fontFace, fontSize, fontWeight);
}
else{
cout << "[WARN] Please pass the path to the ocr model using --ocr_model to get the recognised text." << endl;
}
}
// Drawing detected text regions on the image
polylines(board, contours, true, Scalar(200, 255, 200), 1);
polylines(frame, contours, true, Scalar(0, 255, 0), 1);
} else {
cout << "No Text Detected." << endl;
}
// Displaying the final image with detected and recognized text
Mat stacked;
hconcat(frame, board, stacked);
imshow("Text Detection and Recognition", stacked);
waitKey(0);
}
#define CV_CheckEQ(v1, v2, msg)
Supported values of these types: int, float, double.
Definition check.hpp:120
Designed for command line parsing.
Definition utility.hpp:890
Wrapper on top of a truetype/opentype/etc font, i.e. Freetype's FT_Face.
Definition imgproc.hpp:4991
n-dimensional dense array class
Definition mat.hpp:950
MatSize size
Definition mat.hpp:2447
int cols
Definition mat.hpp:2424
int rows
the number of rows and columns or (-1, -1) when the matrix has more than 2 dimensions
Definition mat.hpp:2424
int type() const
Returns the type of a matrix element.
Template class for specifying the size of an image or rectangle.
Definition types.hpp:338
_Tp height
the height
Definition types.hpp:366
_Tp width
the width
Definition types.hpp:365
This class represents high-level API for text detection DL networks compatible with DB model.
Definition dnn.hpp:2029
This class represents high-level API for text detection DL networks compatible with EAST model.
Definition dnn.hpp:1968
This class represents high-level API for text recognition networks.
Definition dnn.hpp:1812
Scalar mean(InputArray src, InputArray mask=noArray())
Calculates an average (mean) of array elements.
void min(InputArray src1, InputArray src2, OutputArray dst)
Calculates per-element minimum of two arrays or an array and a scalar.
void hconcat(const Mat *src, size_t nsrc, OutputArray dst)
Applies horizontal concatenation to given matrices.
std::string String
Definition cvstd.hpp:151
uint32_t uint
Definition interface.h:42
#define CV_Assert(expr)
Checks a condition at runtime and throws exception if it fails.
Definition exception.hpp:198
void imshow(const String &winname, InputArray mat)
Displays an image in the specified window.
int waitKey(int delay=0)
Waits for a pressed key.
CV_EXPORTS_W Mat imread(const String &filename, int flags=IMREAD_COLOR_BGR)
Loads an image from a file.
void cvtColor(InputArray src, OutputArray dst, int code, int dstCn=0, AlgorithmHint hint=cv::ALGO_HINT_DEFAULT)
Converts an image from one color space to another.
@ COLOR_BGR2GRAY
convert between RGB/BGR and grayscale, color conversions
Definition imgproc.hpp:559
void putText(InputOutputArray img, const String &text, Point org, int fontFace, double fontScale, Scalar color, int thickness=1, int lineType=LINE_8, bool bottomLeftOrigin=false)
Draws a text string.
void polylines(InputOutputArray img, InputArrayOfArrays pts, bool isClosed, const Scalar &color, int thickness=1, int lineType=LINE_8, int shift=0)
Draws several polygonal curves.
void warpPerspective(InputArray src, OutputArray dst, InputArray M, Size dsize, int flags=INTER_LINEAR, int borderMode=BORDER_CONSTANT, const Scalar &borderValue=Scalar(), AlgorithmHint hint=cv::ALGO_HINT_DEFAULT)
Applies a perspective transformation to an image.
Mat getPerspectiveTransform(InputArray src, InputArray dst, int solveMethod=DECOMP_LU)
Calculates a perspective transform from four pairs of the corresponding points.
int main(int argc, char *argv[])
Definition highgui_qt.cpp:3
Definition all_layers.hpp:47
GOpaque< Size > size(const GMat &src)
Gets dimensions from Mat.
Definition core.hpp:107
STL namespace.