98 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Ruby
		
	
	
	
			
		
		
	
	
			98 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Ruby
		
	
	
	
# frozen_string_literal: true
 | 
						|
 | 
						|
module Gitlab
  module Database
    # Installs a row-level trigger that, before every INSERT or UPDATE on a
    # table, copies the values of one set of columns into another set
    # (`NEW.<to> := NEW.<from>` for each column pair). The copy is one-way:
    # whatever was written into the destination columns is overwritten by the
    # source values, and nothing is ever copied back.
    #
    # The trigger and its backing PL/pgSQL function share the same name.
    class UnidirectionalCopyTrigger
      # Builds a trigger helper bound to +table_name+.
      #
      # @param table_name [String, Symbol] table the trigger is attached to
      # @param connection connection used to quote identifiers and execute SQL
      #   (must respond to #quote_table_name, #quote_column_name and #execute)
      # @return [UnidirectionalCopyTrigger]
      def self.on_table(table_name, connection:)
        new(table_name, connection)
      end

      # Returns the deterministic trigger name that #create generates for this
      # table and column mapping (useful for a later #drop).
      #
      # @param from_column_names [String, Symbol, Array] source column(s)
      # @param to_column_names [String, Symbol, Array] destination column(s)
      # @return [String] "trigger_" followed by 12 lowercase hex characters
      # @raise [ArgumentError] if the column lists are empty or differ in size
      def name(from_column_names, to_column_names)
        from_column_names, to_column_names = check_column_names!(from_column_names, to_column_names)

        unchecked_name(from_column_names, to_column_names)
      end

      # Creates (or replaces) the PL/pgSQL copy function and installs a
      # BEFORE INSERT OR UPDATE ... FOR EACH ROW trigger on the table that
      # calls it. Any existing trigger with the same name on the table is
      # dropped first, so calling this again replaces the trigger.
      #
      # The three statements are executed one after another on +connection+;
      # this method opens no transaction of its own.
      #
      # @param from_column_names [String, Symbol, Array] source column(s)
      # @param to_column_names [String, Symbol, Array] destination column(s),
      #   paired positionally with +from_column_names+
      # @param trigger_name [String, nil] name for both the function and the
      #   trigger; defaults to #name for this mapping. NOTE: it is interpolated
      #   into the SQL unquoted, so it must be a trusted, valid identifier.
      # @raise [ArgumentError] if the column lists are empty or differ in size
      def create(from_column_names, to_column_names, trigger_name: nil)
        from_column_names, to_column_names = check_column_names!(from_column_names, to_column_names)
        trigger_name ||= unchecked_name(from_column_names, to_column_names)

        assignment_clauses = assignment_clauses_for_columns(from_column_names, to_column_names)

        connection.execute(<<~SQL)
          CREATE OR REPLACE FUNCTION #{trigger_name}()
          RETURNS trigger AS
          $BODY$
          BEGIN
            #{assignment_clauses};
            RETURN NEW;
          END;
          $BODY$
          LANGUAGE 'plpgsql'
          VOLATILE
        SQL

        connection.execute(<<~SQL)
          DROP TRIGGER IF EXISTS #{trigger_name}
          ON #{quoted_table_name}
        SQL

        connection.execute(<<~SQL)
          CREATE TRIGGER #{trigger_name}
          BEFORE INSERT OR UPDATE
          ON #{quoted_table_name}
          FOR EACH ROW
          EXECUTE FUNCTION #{trigger_name}()
        SQL
      end

      # Removes the trigger from the table and then its backing function.
      # Both statements use IF EXISTS, so dropping an absent trigger is a no-op.
      #
      # @param trigger_name [String] name used for both trigger and function
      #   (interpolated unquoted, like in #create)
      def drop(trigger_name)
        connection.execute("DROP TRIGGER IF EXISTS #{trigger_name} ON #{quoted_table_name}")
        connection.execute("DROP FUNCTION IF EXISTS #{trigger_name}()")
      end

      private

      attr_reader :table_name, :connection

      def initialize(table_name, connection)
        @table_name = table_name
        @connection = connection
      end

      # Table name quoted by the connection; memoized per instance.
      def quoted_table_name
        @quoted_table_name ||= connection.quote_table_name(table_name)
      end

      # Normalizes both arguments to arrays and validates that they pair up.
      #
      # @return [Array(Array, Array)] the normalized from/to column lists
      # @raise [ArgumentError] if the sizes differ or the lists are empty
      def check_column_names!(from_column_names, to_column_names)
        from_column_names = Array.wrap(from_column_names)
        to_column_names = Array.wrap(to_column_names)

        unless from_column_names.size == to_column_names.size
          raise ArgumentError, 'number of source and destination columns must match'
        end

        # An empty mapping would render a function body containing a bare `;`
        # and install a trigger that copies nothing, so reject it up front.
        # (Sizes are equal here, so checking one list is enough.)
        if from_column_names.empty?
          raise ArgumentError, 'at least one source and destination column is required'
        end

        [from_column_names, to_column_names]
      end

      # Derives the trigger name from the table name and the interleaved
      # column pairs, without validating the inputs. Keep this format stable:
      # callers may rely on #name reproducing the name of a trigger created
      # earlier.
      def unchecked_name(from_column_names, to_column_names)
        joined_column_names = from_column_names.zip(to_column_names).flatten.join('_')
        digest = Digest::SHA256.hexdigest("#{table_name}_#{joined_column_names}")

        "trigger_#{digest.first(12)}"
      end

      # Renders one `NEW.<to> := NEW.<from>` PL/pgSQL assignment per column
      # pair, using quoted identifiers. The result has no trailing semicolon;
      # the function template in #create appends it.
      def assignment_clauses_for_columns(from_column_names, to_column_names)
        clauses = to_column_names.zip(from_column_names).map do |to_column, from_column|
          quoted_to = connection.quote_column_name(to_column)
          quoted_from = connection.quote_column_name(from_column)

          "NEW.#{quoted_to} := NEW.#{quoted_from}"
        end

        # The two-space indent lines each following clause up under the first
        # one inside the generated BEGIN ... END block.
        clauses.join(";\n  ")
      end
    end
  end
end
 |