Skip to content

Commit

Permalink
Improved config
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 16, 2024
1 parent a9e1a2d commit 78ba923
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 16 deletions.
11 changes: 2 additions & 9 deletions lib/informers/configs.rb
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
module Informers
class PretrainedConfig
attr_reader :model_type, :problem_type, :id2label, :label2id

def initialize(config_json)
@is_encoder_decoder = false

@model_type = config_json["model_type"]
@problem_type = config_json["problem_type"]
@id2label = config_json["id2label"]
@label2id = config_json["label2id"]
@config_json = config_json
end

def [](key)
instance_variable_get("@#{key}")
@config_json[key.to_s]
end

def self.from_pretrained(
Expand Down
6 changes: 3 additions & 3 deletions lib/informers/models.rb
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ def self.from_pretrained(
end

const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
model_info = model_class_mapping[config.model_type]
model_info = model_class_mapping[config[:model_type]]
if !model_info
next # Item not found in this mapping
end
return model_info[1].from_pretrained(pretrained_model_name_or_path, **options)
end

if const_defined?(:BASE_IF_FAIL)
warn "Unknown model class #{config.model_type.inspect}, attempting to construct from base class."
warn "Unknown model class #{config[:model_type].inspect}, attempting to construct from base class."
PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
else
raise Error, "Unsupported model type: #{config.model_type}"
raise Error, "Unsupported model type: #{config[:model_type]}"
end
end
end
Expand Down
8 changes: 4 additions & 4 deletions lib/informers/pipelines.rb
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ def call(texts, top_k: 1)
outputs = @model.(model_inputs)

function_to_apply =
if @model.config.problem_type == "multi_label_classification"
if @model.config[:problem_type] == "multi_label_classification"
->(batch) { Utils.sigmoid(batch) }
else
->(batch) { Utils.softmax(batch) } # single_label_classification (default)
end

id2label = @model.config.id2label
id2label = @model.config[:id2label]

to_return = []
outputs.logits.each do |batch|
Expand Down Expand Up @@ -70,7 +70,7 @@ def call(
outputs = @model.(model_inputs)

logits = outputs.logits
id2label = @model.config.id2label
id2label = @model.config[:id2label]

to_return = []
logits.length.times do |i|
Expand Down Expand Up @@ -281,7 +281,7 @@ class ZeroShotClassificationPipeline < Pipeline
def initialize(**options)
super(**options)

@label2id = @model.config.label2id.transform_keys(&:downcase)
@label2id = @model.config[:label2id].transform_keys(&:downcase)

@entailment_id = @label2id["entailment"]
if @entailment_id.nil?
Expand Down

0 comments on commit 78ba923

Please sign in to comment.