-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtensorflowlite.h
83 lines (70 loc) · 2.2 KB
/
tensorflowlite.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#ifndef TENSORFLOW_H
#define TENSORFLOW_H
#include <QStringList>
#include <QImage>
#include <QRectF>
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/graph_info.h"
#include "tensorflow/lite/kernels/register.h"
// NOTE(review): a using-directive at header scope injects all of `tflite` into every
// translation unit that includes this header. It is kept because the unqualified
// tflite names below (and presumably the matching .cpp) rely on it; consider
// qualifying them as `tflite::...` and removing this line.
using namespace tflite;

/**
 * @brief Wrapper that owns a TensorFlow Lite model + interpreter and runs it on a QImage.
 *
 * Holds a FlatBufferModel, an Interpreter, a builtin op resolver and a stderr error
 * reporter, plus the results of the most recent run (captions, confidences, boxes)
 * and the inference time. Two network kinds are distinguished by the
 * knIMAGE_CLASSIFIER / knOBJECT_DETECTION constants.
 *
 * Not copyable: it holds std::unique_ptr members and const data members.
 */
class TensorflowLite
{
public:
    // Network-kind identifiers (presumably what getKindNetwork() reports — confirm in .cpp).
    static const int knIMAGE_CLASSIFIER = 1;
    static const int knOBJECT_DETECTION = 2;

    TensorflowLite();

    // NOTE(review): presumably loads modelFilename / labelsFilename and builds the
    // interpreter, returning false on failure — confirm in the .cpp.
    bool init();

    // Score threshold (presumably applied to results of run()).
    double getThreshold() const;
    void setThreshold(double value);

    // Results of the most recent run(); presumably parallel lists of captions,
    // confidences and bounding boxes.
    QStringList getResults();
    QList<double> getConfidence();
    QList<QRectF> getBoxes();

    // Kind of the loaded network; presumably knIMAGE_CLASSIFIER or knOBJECT_DETECTION.
    int getKindNetwork();

    // Runs the model on `img`. QImage is implicitly shared, so passing it by value
    // does not deep-copy the pixels. NOTE(review): presumably returns false on failure.
    bool run(QImage img);

    // Paths of the model file and of the labels file (presumably consumed by init()).
    QString getModelFilename() const;
    void setModelFilename(const QString &value);
    QString getLabelsFilename() const;
    void setLabelsFilename(const QString &value);

    // Image geometry (presumably the model's expected input size).
    int getImgHeight() const;
    int getImgWidth() const;

    // Duration of the last inference. NOTE(review): unit is not visible here — confirm.
    double getInfTime() const;

    // Interpreter thread count and acceleration flag (presumably applied in init()).
    int getNThreads() const;
    void setNThreads(int value);
    bool getAcceleration() const;
    void setAcceleration(bool value);

private:
    // Every scalar below is value-initialized (false / 0 / 0.0) so no getter can
    // return an indeterminate value. Initializers in the constructor's
    // mem-initializer list still take precedence over these defaults.
    bool initialized{};
    double threshold{};
    int nThreads{};
    bool acceleration{};

    // Results
    QStringList rCaption;
    QList<double> rConfidence;
    QList<QRectF> rBox;
    double infTime{};
    int kind_network{};  // 0 matches neither kn* constant until it is set

    // NOTE(review): presumably non-owning pointers to interpreter-owned output tensors.
    std::vector<TfLiteTensor *> outputs;

    // Declaration order matters here. The class has no user-declared destructor, so
    // members are destroyed in reverse declaration order. The interpreter is built
    // from the model, resolver and error reporter, and TFLite requires the
    // FlatBufferModel to stay valid until every Interpreter using it is destroyed.
    // Declaring `interpreter` last among these four guarantees it is torn down
    // first. (Previously it was declared first and therefore destroyed last, after
    // the model it depends on.)
    StderrReporter error_reporter;
    ops::builtin::BuiltinOpResolver resolver;
    std::unique_ptr<FlatBufferModel> model;
    std::unique_ptr<Interpreter> interpreter;

    // Input geometry (presumably read from the model's input tensor).
    int wanted_height{};
    int wanted_width{};
    int wanted_channels{};

    bool inference();
    bool setInputs(QImage image);
    bool getClassfierOutputs(std::vector<std::pair<float, int> > *top_results);
    bool getObjectOutputs(QStringList &captions, QList<double> &confidences, QList<QRectF> &locations);
    bool readLabels();

    QString input_name;
    TfLiteType input_dtype{};
    // NOTE(review): std::unique_ptr will `delete` whatever it holds. If this is ever
    // reset() with a tensor obtained from the interpreter (which owns its tensors),
    // that is a double free. A non-owning `TfLiteTensor *` is likely what is
    // intended — confirm in the .cpp.
    std::unique_ptr<TfLiteTensor> input_tensor;

    QString modelFilename;
    QString labelsFilename;
    QStringList labels;
    QString getLabel(int index);

    int img_height{};
    int img_width{};
    int img_channels{};

    // Pixel format and channel count used for input images (RGB888 => 3 channels).
    const QImage::Format format = QImage::Format_RGB888;
    const int numChannels = 3;
};
#endif // TENSORFLOW_H