diff --git a/lib/mini_sql.rb b/lib/mini_sql.rb index 3b6920f..78d549a 100644 --- a/lib/mini_sql.rb +++ b/lib/mini_sql.rb @@ -29,6 +29,7 @@ module Postgres autoload :PreparedConnection, "mini_sql/postgres/prepared_connection" autoload :PreparedCache, "mini_sql/postgres/prepared_cache" autoload :PreparedBinds, "mini_sql/postgres/prepared_binds" + autoload :PreparedBindsAutoArray, "mini_sql/postgres/prepared_binds_auto_array" end module ActiveRecordPostgres diff --git a/lib/mini_sql/abstract/prepared_binds.rb b/lib/mini_sql/abstract/prepared_binds.rb index 70a36a5..d432695 100644 --- a/lib/mini_sql/abstract/prepared_binds.rb +++ b/lib/mini_sql/abstract/prepared_binds.rb @@ -41,13 +41,12 @@ def bind_hash(sql, hash) def bind_array(sql, array) sql = sql.dup - param_i = 0 + param_i = -1 i = 0 binds = [] bind_names = [] sql.gsub!("?") do - param_i += 1 - array_wrap(array[param_i - 1]).map do |vv| + array_wrap(array[param_i += 1]).map do |vv| binds << vv i += 1 bind_names << [BindName.new("$#{i}")] diff --git a/lib/mini_sql/postgres/connection.rb b/lib/mini_sql/postgres/connection.rb index 38dd2d6..16c3801 100644 --- a/lib/mini_sql/postgres/connection.rb +++ b/lib/mini_sql/postgres/connection.rb @@ -3,7 +3,7 @@ module MiniSql module Postgres class Connection < MiniSql::Connection - attr_reader :raw_connection, :param_encoder, :deserializer_cache + attr_reader :raw_connection, :param_encoder, :deserializer_cache, :array_encoder def self.default_deserializer_cache @deserializer_cache ||= DeserializerCache.new @@ -52,7 +52,7 @@ def self.type_map(conn) def initialize(raw_connection, args = nil) @raw_connection = raw_connection @deserializer_cache = (args && args[:deserializer_cache]) || self.class.default_deserializer_cache - array_encoder = PG::TextEncoder::Array.new if args && args[:auto_encode_arrays] + @array_encoder = PG::TextEncoder::Array.new if args && args[:auto_encode_arrays] @param_encoder = (args && args[:param_encoder]) || InlineParamEncoder.new(self, array_encoder) @type_map = args && args[:type_map] end diff --git a/lib/mini_sql/postgres/prepared_binds_auto_array.rb b/lib/mini_sql/postgres/prepared_binds_auto_array.rb new file mode 100644 index 0000000..7bf5df8 --- /dev/null +++ b/lib/mini_sql/postgres/prepared_binds_auto_array.rb @@ -0,0 +1,61 @@ +# frozen_string_literal: true + +require "mini_sql/abstract/prepared_binds" + +module MiniSql + module Postgres + class PreparedBindsAutoArray < ::MiniSql::Abstract::PreparedBinds + + attr_reader :array_encoder + + def initialize(array_encoder) + @array_encoder = array_encoder + end + + def bind_hash(sql, hash) + sql = sql.dup + binds = [] + bind_names = [] + i = 0 + + hash.each do |k, v| + binds << (v.is_a?(Array) ? array_encoder.encode(v) : v) + bind_names << [BindName.new(k)] + bind_outputs = bind_output(i += 1) + + sql.gsub!(":#{k}") do + # ignore ::int and stuff like that + # $` is previous to match + if $` && $`[-1] != ":" + bind_outputs + else + ":#{k}" + end + end + end + [sql, binds, bind_names] + end + + def bind_array(sql, array) + sql = sql.dup + param_i = -1 + i = 0 + binds = [] + bind_names = [] + sql.gsub!("?") do + v = array[param_i += 1] + binds << (v.is_a?(Array) ? array_encoder.encode(v) : v) + i += 1 + bind_names << [BindName.new("$#{i}")] + bind_output(i) + end + [sql, binds, bind_names] + end + + def bind_output(i) + "$#{i}" + end + + end + end +end diff --git a/lib/mini_sql/postgres/prepared_connection.rb b/lib/mini_sql/postgres/prepared_connection.rb index b69029d..39a9290 100644 --- a/lib/mini_sql/postgres/prepared_connection.rb +++ b/lib/mini_sql/postgres/prepared_connection.rb @@ -13,7 +13,7 @@ def initialize(unprepared_connection) @param_encoder = unprepared_connection.param_encoder @prepared_cache = PreparedCache.new(@raw_connection) - @param_binder = PreparedBinds.new + @param_binder = unprepared.array_encoder ? PreparedBindsAutoArray.new(unprepared.array_encoder) : PreparedBinds.new end def build(_) diff --git a/test/mini_sql/postgres/connection_test.rb b/test/mini_sql/postgres/connection_test.rb index ac90efa..b63c034 100644 --- a/test/mini_sql/postgres/connection_test.rb +++ b/test/mini_sql/postgres/connection_test.rb @@ -155,17 +155,26 @@ def test_unamed_query assert_equal(row.column2, 3) end - def test_encode_array + def test_array_with_auto_encode_arrays connection = pg_connection(auto_encode_arrays: true) ints = [1, 2, 3] strings = %w[a b c] empty_array = [] - row = connection.query("select ?::int[] ints, ?::text[] strings, ?::int[] empty_array", ints, strings, empty_array).first + row = connection.query_single("select ?::int[], ?::text[], ?::int[]", ints, strings, empty_array) - assert_equal(row.ints, ints) - assert_equal(row.strings, strings) - assert_equal(row.empty_array, empty_array) + assert_equal(row, [ints, strings, empty_array]) + end + + def test_simple_with_auto_encode_arrays + connection = pg_connection(auto_encode_arrays: true) + + int = 1 + str = "str" + date = Date.new(2020, 10, 10) + row = connection.query_single("select ?, ?, ?::date", int, str, date) + + assert_equal(row, [int, str, date]) end end diff --git a/test/mini_sql/postgres/prepared_connection_test.rb b/test/mini_sql/postgres/prepared_connection_test.rb index edbe48a..32c28dd 100644 --- a/test/mini_sql/postgres/prepared_connection_test.rb +++ b/test/mini_sql/postgres/prepared_connection_test.rb @@ -82,4 +82,37 @@ def test_single_named_param assert_last_stmt "select $1, $1, $1" assert_equal %w[test test test], r end + + def test_array_with_auto_encode_arrays + connection = pg_connection(auto_encode_arrays: true).prepared + + ints = [1, 2, 3] + strings = %w[a b c] + empty_array = [] + row = connection.query_single("select ?::int[], ?::text[], ?::int[]", ints, strings, empty_array) + + assert_equal(row, [ints, strings, empty_array]) + end + + def test_simple_with_auto_encode_arrays + connection = pg_connection(auto_encode_arrays: true).prepared + + int = 1 + str = "str" + date = Date.new(2020, 10, 10) + row = connection.query_single("select ?::int, ?, ?::date", int, str, date) + + assert_equal(row, [int, str, date]) + end + + def test_hash_params_with_auto_encode_arrays + connection = pg_connection(auto_encode_arrays: true).prepared + + num = 1 + date = Date.new(2020, 10, 10) + ints = [1, 2, 3] + row = connection.query_single("select :num::int, :date::date, :ints::int[]", num: num, date: date, ints: ints) + + assert_equal(row, [num, date, ints]) + end end