Skip to content

Commit

Permalink
Added object-detection pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 16, 2024
1 parent fe23373 commit 20acf56
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Added `fill-mask` pipeline
- Added `image-classification` pipeline
- Added `image-feature-extraction` pipeline
- Added `object-detection` pipeline

## 1.0.3 (2024-08-29)

Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,13 @@ classifier = Informers.pipeline("zero-shot-image-classification")
classifier.(URI("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"), ["cat", "dog", "tiger"])
```

Object detection [unreleased]

```ruby
detector = Informers.pipeline("object-detection")
detector.(URI("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"))
```

Image feature extraction [unreleased]

```ruby
Expand Down
32 changes: 32 additions & 0 deletions lib/informers/models.rb
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,18 @@ class CLIPPreTrainedModel < PreTrainedModel
class CLIPModel < CLIPPreTrainedModel
end

class DetrPreTrainedModel < PreTrainedModel
end

class DetrModel < DetrPreTrainedModel
end

class DetrForObjectDetection < DetrPreTrainedModel
def call(model_inputs)
DetrObjectDetectionOutput.new(*super(model_inputs))
end
end

MODEL_MAPPING_NAMES_ENCODER_ONLY = {
"bert" => ["BertModel", BertModel],
"nomic_bert" => ["NomicBertModel", NomicBertModel],
Expand All @@ -312,6 +324,7 @@ class CLIPModel < CLIPPreTrainedModel
"roberta" => ["RobertaModel", RobertaModel],
"xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel],
"clip" => ["CLIPModel", CLIPModel],
"detr" => ["DetrModel", DetrModel],
"vit" => ["ViTModel", ViTModel]
}

Expand Down Expand Up @@ -343,6 +356,10 @@ class CLIPModel < CLIPPreTrainedModel
"vit" => ["ViTForImageClassification", ViTForImageClassification]
}

MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = {
"detr" => ["DetrForObjectDetection", DetrForObjectDetection]
}

MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = {
}

Expand All @@ -354,6 +371,7 @@ class CLIPModel < CLIPPreTrainedModel
[MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
[MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
[MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
[MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
]

Expand Down Expand Up @@ -390,6 +408,10 @@ class AutoModelForImageClassification < PretrainedMixin
MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]
end

class AutoModelForObjectDetection < PretrainedMixin
MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]
end

class AutoModelForImageFeatureExtraction < PretrainedMixin
MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]
end
Expand Down Expand Up @@ -433,4 +455,14 @@ def initialize(start_logits, end_logits)
@end_logits = end_logits
end
end

class DetrObjectDetectionOutput < ModelOutput
attr_reader :logits, :pred_boxes

def initialize(logits, pred_boxes)
super()
@logits = logits
@pred_boxes = pred_boxes
end
end
end
55 changes: 55 additions & 0 deletions lib/informers/pipelines.rb
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,52 @@ def call(images, candidate_labels, hypothesis_template: "This is a photo of {}")
end
end

class ObjectDetectionPipeline < Pipeline
def call(images, threshold: 0.9, percentage: false)
is_batched = images.is_a?(Array)

if is_batched && images.length != 1
raise Error, "Object detection pipeline currently only supports a batch size of 1."
end
prepared_images = prepare_images(images)

image_sizes = percentage ? nil : prepared_images.map { |x| [x.height, x.width] }

model_inputs = @processor.(prepared_images).slice(:pixel_values, :pixel_mask)
output = @model.(model_inputs)

processed = @processor.feature_extractor.post_process_object_detection(output, threshold, image_sizes)

# Add labels
id2label = @model.config[:id2label]

# Format output
result =
processed.map do |batch|
batch[:boxes].map.with_index do |box, i|
{
score: batch[:scores][i],
label: id2label[batch[:classes][i].to_s],
box: get_bounding_box(box, !percentage)
}
end
end

is_batched ? result : result[0]
end

private

def get_bounding_box(box, as_integer)
if as_integer
box = box.map { |x| x.to_i }
end
xmin, ymin, xmax, ymax = box

{xmin:, ymin:, xmax:, ymax:}
end
end

class FeatureExtractionPipeline < Pipeline
def call(
texts,
Expand Down Expand Up @@ -622,6 +668,15 @@ def call(
},
type: "multimodal"
},
"object-detection" => {
pipeline: ObjectDetectionPipeline,
model: AutoModelForObjectDetection,
processor: AutoProcessor,
default: {
model: "Xenova/detr-resnet-50",
},
type: "multimodal"
},
"feature-extraction" => {
tokenizer: AutoTokenizer,
pipeline: FeatureExtractionPipeline,
Expand Down
101 changes: 99 additions & 2 deletions lib/informers/processors.rb
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,12 @@ def preprocess(

# do padding after rescaling/normalizing
if !do_pad.nil? ? do_pad : @do_pad
raise Todo
if @pad_size
padded = pad_image(pixel_data, [image.height, image.width, image.channels], @pad_size)
pixel_data, img_dims = padded # Update pixel data and image dimensions
elsif @size_divisibility
raise Todo
end
end

if !do_flip_channel_order.nil? ? do_flip_channel_order : @do_flip_channel_order
Expand Down Expand Up @@ -245,7 +250,98 @@ class CLIPFeatureExtractor < ImageFeatureExtractor
class ViTFeatureExtractor < ImageFeatureExtractor
end

class DetrFeatureExtractor < ImageFeatureExtractor
def call(images)
result = super(images)

# TODO support differently-sized images, for now assume all images are the same size.
# TODO support different mask sizes (not just 64x64)
# Currently, just fill pixel mask with 1s
mask_size = [result[:pixel_values].size, 64, 64]
pixel_mask =
mask_size[0].times.map do
mask_size[1].times.map do
mask_size[2].times.map do
1
end
end
end

result.merge(pixel_mask: pixel_mask)
end

def center_to_corners_format(v)
centerX, centerY, width, height = v
[
centerX - width / 2.0,
centerY - height / 2.0,
centerX + width / 2.0,
centerY + height / 2.0
]
end

def post_process_object_detection(outputs, threshold = 0.5, target_sizes = nil, is_zero_shot = false)
out_logits = outputs.logits
out_bbox = outputs.pred_boxes
batch_size, num_boxes, num_classes = out_logits.size, out_logits[0].size, out_logits[0][0].size

if !target_sizes.nil? && target_sizes.length != batch_size
raise Error, "Make sure that you pass in as many target sizes as the batch dimension of the logits"
end
to_return = []
batch_size.times do |i|
target_size = !target_sizes.nil? ? target_sizes[i] : nil
info = {
boxes: [],
classes: [],
scores: []
}
logits = out_logits[i]
bbox = out_bbox[i]

num_boxes.times do |j|
logit = logits[j]

indices = []
if is_zero_shot
raise Todo
else
# Get most probable class
max_index = Utils.max(logit)[1]

if max_index == num_classes - 1
# This is the background class, skip it
next
end
indices << max_index

# Compute softmax over classes
probs = Utils.softmax(logit)
end

indices.each do |index|
box = bbox[j]

# convert to [x0, y0, x1, y1] format
box = center_to_corners_format(box)
if !target_size.nil?
box = box.map.with_index { |x, i| x * target_size[(i + 1) % 2] }
end

info[:boxes] << box
info[:classes] << index
info[:scores] << probs[index]
end
end
to_return << info
end
to_return
end
end

class Processor
attr_reader :feature_extractor

def initialize(feature_extractor)
@feature_extractor = feature_extractor
end
Expand All @@ -258,7 +354,8 @@ def call(input, *args)
class AutoProcessor
FEATURE_EXTRACTOR_CLASS_MAPPING = {
"ViTFeatureExtractor" => ViTFeatureExtractor,
"CLIPFeatureExtractor" => CLIPFeatureExtractor
"CLIPFeatureExtractor" => CLIPFeatureExtractor,
"DetrFeatureExtractor" => DetrFeatureExtractor
}

PROCESSOR_CLASS_MAPPING = {}
Expand Down
20 changes: 20 additions & 0 deletions test/pipeline_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,26 @@ def test_zero_shot_image_classification
assert_in_delta 0.055, result[2][:score]
end

def test_object_detection
detector = Informers.pipeline("object-detection")
result = detector.("test/support/pipeline-cat-chonk.jpeg")
assert_equal 3, result.size

assert_in_delta 0.742, result[0][:score]
assert_equal "cat", result[0][:label]
assert_equal 177, result[0][:box][:xmin]
assert_equal 153, result[0][:box][:ymin]
assert_equal 885, result[0][:box][:xmax]
assert_equal 600, result[0][:box][:ymax]

assert_in_delta 0.726, result[1][:score]
assert_equal "bicycle", result[1][:label]
assert_equal 0, result[1][:box][:xmin]
assert_equal 0, result[1][:box][:ymin]
assert_equal 196, result[1][:box][:xmax]
assert_equal 413, result[1][:box][:ymax]
end

def test_image_feature_extraction
fe = Informers.pipeline("image-feature-extraction")
result = fe.("test/support/pipeline-cat-chonk.jpeg")
Expand Down

0 comments on commit 20acf56

Please sign in to comment.