forked from google-coral/libedgetpu
-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathdriver_helper.h
206 lines (159 loc) · 7.4 KB
/
driver_helper.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DARWINN_DRIVER_DRIVER_HELPER_H_
#define DARWINN_DRIVER_DRIVER_HELPER_H_
#include <unistd.h>
#include <chrono> // NOLINT
#include <condition_variable> // NOLINT
#include <fstream>
#include <memory>
#include <string>
#include <vector>
#include "api/buffer.h"
#include "api/chip.h"
#include "api/driver.h"
#include "api/package_reference.h"
#include "api/request.h"
#include "api/telemeter_interface.h"
#include "api/timing.h"
#include "driver/executable_util.h"
#include "driver/package_registry.h"
#include "driver/test_vector.h"
#include "executable/executable_generated.h"
#include "port/errors.h"
#include "port/logging.h"
#include "port/ptr_util.h"
#include "port/status.h"
#include "port/status_macros.h"
#include "port/std_mutex_lock.h"
#include "port/stringprintf.h"
#include "port/thread_annotations.h"
namespace platforms {
namespace darwinn {
namespace driver {
// Wraps a driver instance with additional functions that performs tests and
// verify results.
class DriverHelper : public api::Driver {
public:
DriverHelper(std::unique_ptr<api::Driver> driver, int max_pending_requests,
bool prefill_output_tensors = false,
size_t guard_area_size_bytes = 0);
// Destructor. Waits for pending tasks to avoid Submit callbacks
// acquiring otherwise-destructed mutex_. See b/111616778.
~DriverHelper() override {
if (IsOpen()) CHECK_OK(Close(api::Driver::ClosingMode::kGraceful));
}
Status Open(bool debug_mode, bool context_lost = false) final
LOCKS_EXCLUDED(mutex_);
Status Close(api::Driver::ClosingMode mode) final LOCKS_EXCLUDED(mutex_);
bool IsOpen() const final LOCKS_EXCLUDED(mutex_);
bool IsError() const final;
StatusOr<const api::PackageReference*> RegisterExecutableFile(
const std::string& executable_filename) final;
StatusOr<const api::PackageReference*> RegisterExecutableSerialized(
const std::string& executable_content) final;
StatusOr<const api::PackageReference*> RegisterExecutableSerialized(
const char* executable_content, size_t length) final;
Status UnregisterExecutable(
const api::PackageReference* executable_ref) final;
StatusOr<std::shared_ptr<api::Request>> CreateRequest(
const api::PackageReference* executable_ref) final;
Status Submit(std::shared_ptr<api::Request> request,
api::Request::Done done_callback) final;
Status Execute(std::shared_ptr<api::Request> request) final;
Status Execute(
const std::vector<std::shared_ptr<api::Request>>& requests) final;
Status Cancel(std::shared_ptr<api::Request> request) final;
Status CancelAllRequests() final;
uint64_t allocation_alignment_bytes() const final;
Buffer MakeBuffer(size_t size_bytes) const final;
void SetFatalErrorCallback(FatalErrorCallback callback) final;
void SetThermalWarningCallback(ThermalWarningCallback callback) final;
Status SetExecutionPreference(const api::PackageReference* package,
ExecutionPreference preference) final {
return OkStatus();
}
// Extensions to the Device interface to facilitate easier testing.
// Submits an inference request with given test vector.
Status Submit(const TestVector& test_vector, int batches)
LOCKS_EXCLUDED(mutex_);
// Submits an inference request and execute the specified callback on
// completion. |tag| is a user friendly name for tracking this request
// (typically the model name).
Status Submit(const std::string& tag,
const api::PackageReference* executable_ref,
const Buffer::NamedMap& input, const Buffer::NamedMap& output,
const Buffer::NamedMap& output_with_guard_areas,
api::Request::Done request_done) LOCKS_EXCLUDED(mutex_);
// Submits an inference request and verify output, with optional guard area
// sorrounding the output buffers. Dumps the output upon mismatch, if
// output_file_name is not empty.
Status Submit(
const std::string& tag, const api::PackageReference* executable_ref,
const std::string& output_file_name, const Buffer::NamedMap& input,
const Buffer::NamedMap& expected_output, const Buffer::NamedMap& output,
const Buffer::NamedMap& output_with_guard_areas) LOCKS_EXCLUDED(mutex_);
// Submits an inference request and verify output.
Status Submit(const std::string& tag,
const api::PackageReference* executable_ref,
const Buffer::NamedMap& input,
const Buffer::NamedMap& expected_output,
const Buffer::NamedMap& output) LOCKS_EXCLUDED(mutex_);
Status SetRealtimeMode(bool on) override;
Status SetExecutableTiming(const api::PackageReference* executable,
const api::Timing& timing) override;
void SetTelemeterInterface(
api::TelemeterInterface* telemeter_interface) override {}
void UpdateOperationalSettings(const OperationalSettings& settings) override {
driver_->UpdateOperationalSettings(settings);
}
private:
// Wrapped driver instance.
std::unique_ptr<api::Driver> driver_;
// Maximum number of pending requests.
const int max_pending_requests_{1};
// Current number of pending requests.
int pending_requests_ GUARDED_BY(mutex_){0};
// Total number of requests processed so far.
int total_requests_ GUARDED_BY(mutex_){0};
// Condition variable to synchronously wait for pending requests.
std::condition_variable cv_ GUARDED_BY(mutex_);
// Guards pending_requests_, cv_ and other internal states.
mutable std::mutex mutex_;
// Time at which first submit was called.
std::chrono::steady_clock::time_point first_submit_;
// A vector of roundtrip times for all requests in milliseconds. Roundtrip
// time is measured from when driver::submit is called until the callback is
// first received.
std::vector<double> roundtrip_times_ms_;
// A vector of verification times for all requests in milliseconds.
// Verification time is measured from when the callback is first received
// until the callback is completed.
std::vector<double> verification_times_ms_;
// If true, the output tensors are pre-filled with known data pattern.
// This helps catch incomplete output activations, i.e. when any parts of the
// output memory region are not overwritten.
const bool prefill_output_tensors_;
// If non-zero, leading and trailing guard areas would be allocated for every
// output buffer, and filled with known data pattern. These guard areas would
// then be checked when a request is completed, to detect data overflow.
// The size should be page-aligned for PCIe use cases.
// Note that in cases the driver always makes a copy of the output buffers,
// this mechanism would only catch driver-caused overruns.
const size_t guard_area_size_bytes_;
};
} // namespace driver
} // namespace darwinn
} // namespace platforms
#endif // DARWINN_DRIVER_DRIVER_HELPER_H_