98 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Ruby
		
	
	
	
			
		
		
	
	
			98 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Ruby
		
	
	
	
| # frozen_string_literal: true
 | |
| 
 | |
module Gitlab
  module Database
    # Installs a PostgreSQL row-level trigger that, on every INSERT or UPDATE,
    # copies the values of a list of source columns into a list of destination
    # columns of the same row. Data only flows one way: source -> destination.
    #
    # Example:
    #   trigger = UnidirectionalCopyTrigger.on_table(:events, connection: connection)
    #   trigger.create(:id, :id_convert_to_bigint)
    #   trigger.drop(trigger.name(:id, :id_convert_to_bigint))
    class UnidirectionalCopyTrigger
      # @param table_name [String, Symbol] table the trigger is attached to
      # @param connection [Object] database connection; must respond to
      #   #execute, #quote_table_name and #quote_column_name
      # @return [UnidirectionalCopyTrigger]
      def self.on_table(table_name, connection:)
        new(table_name, connection)
      end

      # Deterministic name used for both the trigger function and the trigger
      # when #create is called without an explicit trigger_name.
      #
      # @param from_column_names [Array, String, Symbol] source column(s)
      # @param to_column_names [Array, String, Symbol] destination column(s)
      # @return [String] "trigger_" followed by the first 12 hex characters of
      #   the SHA256 digest of "<table>_<from1>_<to1>_<from2>_<to2>..."
      # @raise [ArgumentError] if the column lists differ in length or are empty
      def name(from_column_names, to_column_names)
        from_column_names, to_column_names = check_column_names!(from_column_names, to_column_names)

        unchecked_name(from_column_names, to_column_names)
      end

      # Creates (or replaces) a PL/pgSQL function that assigns each source
      # column to its paired destination column on NEW. It then drops any
      # existing trigger of the same name on the table and creates a
      # BEFORE INSERT OR UPDATE ... FOR EACH ROW trigger that executes the
      # function. The function and the trigger share the same name.
      #
      # NOTE(review): the three statements are issued separately. Whether they
      # run atomically depends on the caller's surrounding transaction.
      # NOTE(review): trigger_name is interpolated verbatim (unquoted) into the
      # SQL, so it must be a trusted, valid identifier.
      #
      # @param from_column_names [Array, String, Symbol] source column(s)
      # @param to_column_names [Array, String, Symbol] destination column(s),
      #   paired positionally with from_column_names
      # @param trigger_name [String, nil] defaults to #name(from, to)
      # @raise [ArgumentError] if the column lists differ in length or are empty
      def create(from_column_names, to_column_names, trigger_name: nil)
        from_column_names, to_column_names = check_column_names!(from_column_names, to_column_names)
        trigger_name ||= unchecked_name(from_column_names, to_column_names)

        assignment_clauses = assignment_clauses_for_columns(from_column_names, to_column_names)

        connection.execute(<<~SQL)
          CREATE OR REPLACE FUNCTION #{trigger_name}()
          RETURNS trigger AS
          $BODY$
          BEGIN
            #{assignment_clauses};
            RETURN NEW;
          END;
          $BODY$
          LANGUAGE plpgsql
          VOLATILE
        SQL

        connection.execute(<<~SQL)
          DROP TRIGGER IF EXISTS #{trigger_name}
          ON #{quoted_table_name}
        SQL

        connection.execute(<<~SQL)
          CREATE TRIGGER #{trigger_name}
          BEFORE INSERT OR UPDATE
          ON #{quoted_table_name}
          FOR EACH ROW
          EXECUTE FUNCTION #{trigger_name}()
        SQL
      end

      # Drops the trigger from the table and then its function. Both use
      # IF EXISTS, so dropping a missing trigger is not an error.
      #
      # @param trigger_name [String] name passed to / produced by #create
      #   (interpolated verbatim, see #create)
      def drop(trigger_name)
        connection.execute("DROP TRIGGER IF EXISTS #{trigger_name} ON #{quoted_table_name}")
        connection.execute("DROP FUNCTION IF EXISTS #{trigger_name}()")
      end

      private

      attr_reader :table_name, :connection

      def initialize(table_name, connection)
        @table_name = table_name
        @connection = connection
      end

      # Table name quoted by the connection, memoized per instance.
      def quoted_table_name
        @quoted_table_name ||= connection.quote_table_name(table_name)
      end

      # Normalizes both arguments to arrays and validates that they pair up.
      #
      # @return [Array(Array, Array)] the normalized [from, to] column lists
      # @raise [ArgumentError] if the lists differ in length or are empty
      def check_column_names!(from_column_names, to_column_names)
        from_column_names = Array.wrap(from_column_names)
        to_column_names = Array.wrap(to_column_names)

        unless from_column_names.size == to_column_names.size
          raise ArgumentError, 'number of source and destination columns must match'
        end

        # With no columns the function body would contain a bare ";", which is
        # not a valid PL/pgSQL statement. Fail here rather than in the database.
        if from_column_names.empty?
          raise ArgumentError, 'at least one source and destination column must be given'
        end

        [from_column_names, to_column_names]
      end

      # Builds the trigger name without validating the column lists.
      def unchecked_name(from_column_names, to_column_names)
        joined_column_names = from_column_names.zip(to_column_names).flatten.join('_')
        digest = Digest::SHA256.hexdigest("#{table_name}_#{joined_column_names}")

        "trigger_#{digest.first(12)}"
      end

      # Renders `NEW.<to> := NEW.<from>` for each column pair, with quoted
      # column names. Clauses are joined by ";" plus a newline and indent for
      # readability inside the function body.
      def assignment_clauses_for_columns(from_column_names, to_column_names)
        clauses = to_column_names.zip(from_column_names).map do |to_column, from_column|
          quoted_to = connection.quote_column_name(to_column)
          quoted_from = connection.quote_column_name(from_column)

          "NEW.#{quoted_to} := NEW.#{quoted_from}"
        end

        clauses.join(";\n  ")
      end
    end
  end
end
 |