
Commit

init

elk-cloner committed Feb 15, 2025
1 parent 14d993b commit 042f225
Showing 8 changed files with 1,391 additions and 155 deletions.
11 changes: 0 additions & 11 deletions .vscode/settings.json

This file was deleted.

132 changes: 95 additions & 37 deletions candle-examples/examples/mllama/config.rs
@@ -1,11 +1,5 @@
-use std::default;
-
-use anyhow::Ok;
-use anyhow::{bail, Error as E, Result};
-use candle::{DType, Device, Tensor};
-use candle_nn::{init, Embedding, Module, VarBuilder};
-use hf_hub::{api::sync::Api, Repo, RepoType};
 use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
 
 #[derive(Debug, Clone, Serialize, serde::Deserialize, Default)]
 pub enum Llama3RopeType {
@@ -28,64 +22,64 @@ pub struct Llama3RopeConfig
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct MllamaTextConfig {
     #[serde(default = "MllamaTextConfig::vocab_size")]
-    vocab_size: i32,
+    pub vocab_size: usize,
     #[serde(default = "MllamaTextConfig::hidden_size")]
-    hidden_size: i32,
+    pub hidden_size: usize,
     #[serde(default = "MllamaTextConfig::hidden_act")]
-    hidden_act: String,
+    pub hidden_act: String,
     #[serde(default = "MllamaTextConfig::num_hidden_layers")]
-    num_hidden_layers: i32,
+    pub num_hidden_layers: usize,
     #[serde(default = "MllamaTextConfig::num_attention_heads")]
-    num_attention_heads: i32,
+    pub num_attention_heads: usize,
     #[serde(default = "MllamaTextConfig::num_key_value_heads")]
-    num_key_value_heads: i32,
+    pub num_key_value_heads: usize,
     #[serde(default = "MllamaTextConfig::intermediate_size")]
-    intermediate_size: i32,
+    pub intermediate_size: usize,
     #[serde(default = "MllamaTextConfig::rope_theta")]
-    rope_theta: f32,
+    pub rope_theta: f32,
     #[serde(default = "MllamaTextConfig::rope_scaling")]
-    rope_scaling: Option<Llama3RopeConfig>,
+    pub rope_scaling: Option<Llama3RopeConfig>,
     #[serde(default = "MllamaTextConfig::rms_norm_eps")]
-    rms_norm_eps: f32,
+    pub rms_norm_eps: f32,
     #[serde(default = "MllamaTextConfig::max_position_embeddings")]
-    max_position_embeddings: i32,
+    pub max_position_embeddings: usize,
     #[serde(default = "MllamaTextConfig::initializer_range")]
-    initializer_range: f32,
+    pub initializer_range: f32,
     #[serde(default = "MllamaTextConfig::use_cache")]
-    use_cache: bool,
+    pub use_cache: bool,
     #[serde(default = "MllamaTextConfig::tie_word_embeddings")]
-    tie_word_embeddings: bool,
+    pub tie_word_embeddings: bool,
     #[serde(default = "MllamaTextConfig::cross_attention_layers")]
-    cross_attention_layers: Option<Vec<i32>>,
+    pub cross_attention_layers: Option<Vec<usize>>,
     #[serde(default = "MllamaTextConfig::dropout")]
-    dropout: f32,
+    pub dropout: f32,
     #[serde(default = "MllamaTextConfig::bos_token_id")]
-    bos_token_id: i32,
+    pub bos_token_id: usize,
     #[serde(default = "MllamaTextConfig::eos_token_id")]
-    eos_token_id: i32,
+    pub eos_token_id: usize,
     #[serde(default = "MllamaTextConfig::pad_token_id")]
-    pad_token_id: Option<i32>,
+    pub pad_token_id: Option<usize>,
 }
 impl MllamaTextConfig {
-    fn vocab_size() -> i32 {
+    fn vocab_size() -> usize {
         128256
     }
-    fn hidden_size() -> i32 {
+    fn hidden_size() -> usize {
         4096
     }
     fn hidden_act() -> String {
         String::from("silu")
     }
-    fn num_hidden_layers() -> i32 {
+    fn num_hidden_layers() -> usize {
         40
     }
-    fn num_attention_heads() -> i32 {
+    fn num_attention_heads() -> usize {
         32
     }
-    fn num_key_value_heads() -> i32 {
+    fn num_key_value_heads() -> usize {
         8
     }
-    fn intermediate_size() -> i32 {
+    fn intermediate_size() -> usize {
         14_336
     }
     fn rope_theta() -> f32 {
@@ -97,7 +91,7 @@ impl MllamaTextConfig {
     fn rms_norm_eps() -> f32 {
         1e-5
     }
-    fn max_position_embeddings() -> i32 {
+    fn max_position_embeddings() -> usize {
         131_072
     }
     fn initializer_range() -> f32 {
@@ -109,19 +103,19 @@ impl MllamaTextConfig {
     fn tie_word_embeddings() -> bool {
         false
     }
-    fn cross_attention_layers() -> Option<Vec<i32>> {
+    fn cross_attention_layers() -> Option<Vec<usize>> {
         None
    }
     fn dropout() -> f32 {
         0.0
     }
-    fn bos_token_id() -> i32 {
+    fn bos_token_id() -> usize {
         128000
     }
-    fn eos_token_id() -> i32 {
+    fn eos_token_id() -> usize {
         128001
     }
-    fn pad_token_id() -> Option<i32> {
+    fn pad_token_id() -> Option<usize> {
         Some(128004)
     }
 }
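Every field default above is wired through serde's #[serde(default = "path")] attribute, which names an associated function to call whenever the key is missing from the JSON. A minimal sketch of the pattern, using a hypothetical two-field TextConfig rather than the real MllamaTextConfig, and assuming serde (with the derive feature) and serde_json as dependencies:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct TextConfig {
    #[serde(default = "TextConfig::vocab_size")]
    vocab_size: usize,
    #[serde(default = "TextConfig::rope_theta")]
    rope_theta: f32,
}

impl TextConfig {
    // Each default is a plain associated function; the attribute refers
    // to it by path as a string literal.
    fn vocab_size() -> usize {
        128_256
    }
    fn rope_theta() -> f32 {
        // Hypothetical placeholder; the commit's actual rope_theta default
        // sits in a collapsed region of the diff.
        10_000.0
    }
}

fn main() -> Result<(), serde_json::Error> {
    // vocab_size is absent, so its default function supplies 128_256;
    // rope_theta is present, so the JSON value wins.
    let cfg: TextConfig = serde_json::from_str(r#"{ "rope_theta": 500000.0 }"#)?;
    assert_eq!(cfg.vocab_size, 128_256);
    assert_eq!(cfg.rope_theta, 500_000.0);
    println!("{cfg:?}");
    Ok(())
}

Because the attribute applies per field, even an empty JSON object ("{}") deserializes successfully, which is what lets the example run against config files that omit any subset of these keys.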
@@ -256,3 +250,67 @@ impl MllamaConfig {
         128256
     }
 }
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ImagePreProcessorConfig {
+    #[serde(default = "ImagePreProcessorConfig::do_convert_rgb")]
+    pub do_convert_rgb: bool,
+    #[serde(default = "ImagePreProcessorConfig::do_normalize")]
+    pub do_normalize: bool,
+    #[serde(default = "ImagePreProcessorConfig::do_pad")]
+    pub do_pad: bool,
+    #[serde(default = "ImagePreProcessorConfig::do_rescale")]
+    pub do_rescale: bool,
+    #[serde(default = "ImagePreProcessorConfig::do_resize")]
+    pub do_resize: bool,
+    #[serde(default = "ImagePreProcessorConfig::image_mean")]
+    pub image_mean: Vec<f32>,
+    #[serde(default = "ImagePreProcessorConfig::image_std")]
+    pub image_std: Vec<f32>,
+    #[serde(default = "ImagePreProcessorConfig::max_image_tiles")]
+    pub max_image_tiles: usize,
+    #[serde(default = "ImagePreProcessorConfig::resample")]
+    pub resample: usize,
+    #[serde(default = "ImagePreProcessorConfig::rescale_factor")]
+    pub rescale_factor: f32,
+    #[serde(default = "ImagePreProcessorConfig::size")]
+    pub size: HashMap<String, usize>,
+}
+impl ImagePreProcessorConfig {
+    fn do_convert_rgb() -> bool {
+        true
+    }
+    fn do_normalize() -> bool {
+        true
+    }
+    fn do_pad() -> bool {
+        true
+    }
+    fn do_rescale() -> bool {
+        true
+    }
+    fn do_resize() -> bool {
+        true
+    }
+    fn image_mean() -> Vec<f32> {
+        vec![0.48145466, 0.4578275, 0.40821073]
+    }
+    fn image_std() -> Vec<f32> {
+        vec![0.26862954, 0.26130258, 0.27577711]
+    }
+    fn max_image_tiles() -> usize {
+        4
+    }
+    fn resample() -> usize {
+        2
+    }
+    fn rescale_factor() -> f32 {
+        0.00392156862745098
+    }
+    fn size() -> HashMap<String, usize> {
+        let mut size = HashMap::new();
+        size.insert(String::from("height"), 448);
+        size.insert(String::from("width"), 448);
+        size
+    }
+}
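The ImagePreProcessorConfig defaults mirror CLIP-style preprocessing: rescale_factor is exactly 1/255 (0.00392156862745098), image_mean and image_std are the usual per-channel normalization constants, and resample = 2 matches PIL's bilinear filter in the Hugging Face convention. Below is a short illustrative sketch of the arithmetic the do_rescale and do_normalize flags imply for a single RGB pixel; it is hand-written for clarity, not the example's actual preprocessing code:

fn main() {
    // Defaults copied from ImagePreProcessorConfig above.
    let image_mean = [0.48145466_f32, 0.4578275, 0.40821073];
    let image_std = [0.26862954_f32, 0.26130258, 0.27577711];
    let rescale_factor = 0.00392156862745098_f32; // exactly 1/255

    // One RGB pixel as raw u8 channel values.
    let pixel: [u8; 3] = [128, 64, 200];

    let normalized: Vec<f32> = pixel
        .iter()
        .enumerate()
        // do_rescale: map [0, 255] into [0.0, 1.0], then
        // do_normalize: subtract the per-channel mean and divide by std.
        .map(|(c, &v)| (v as f32 * rescale_factor - image_mean[c]) / image_std[c])
        .collect();

    println!("{normalized:?}");
}

The size default together with max_image_tiles = 4 fits Mllama's tiled vision input, where an oversized image is split into up to max_image_tiles tiles of the configured size before encoding.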