Skip to content

Commit

Permalink
make sure no references column is updated
Browse files Browse the repository at this point in the history
use f-strings for visibility
  • Loading branch information
hichamlahlou committed Feb 21, 2025
1 parent 427c5b8 commit 2f2b529
Showing 1 changed file with 43 additions and 23 deletions.
66 changes: 43 additions & 23 deletions freppledb/input/commands/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,33 @@ def getWeight(cls, **kwargs):
@classmethod
def run(cls, database=DEFAULT_DB_ALIAS, **kwargs):
with connections[database].cursor() as cursor:

cursor.execute(
"""
select rel_kcu.table_name as primary_table,
rel_kcu.column_name as primary_column
from information_schema.table_constraints tco
join information_schema.key_column_usage kcu
on tco.constraint_schema = kcu.constraint_schema
and tco.constraint_name = kcu.constraint_name
join information_schema.referential_constraints rco
on tco.constraint_schema = rco.constraint_schema
and tco.constraint_name = rco.constraint_name
join information_schema.key_column_usage rel_kcu
on rco.unique_constraint_schema = rel_kcu.constraint_schema
and rco.unique_constraint_name = rel_kcu.constraint_name
and kcu.ordinal_position = rel_kcu.ordinal_position
where tco.constraint_type = 'FOREIGN KEY'
"""
)
foreign_key_exists = [i for i in cursor]

# check 1: make sure the sequence has not reached 90% of the max value
cursor.execute(
"""
with cte as (
select sequencename from pg_sequences
where schemaname='public'
and sequencename not like 'django%' and sequencename not like 'auth%'
and last_value > 0.9 * max_value)
select s.relname as sequencename,
t.relname as tablename,
Expand All @@ -204,10 +224,10 @@ def run(cls, database=DEFAULT_DB_ALIAS, **kwargs):
inner join pg_namespace n on n.oid=t.relnamespace
inner join pg_attribute a on a.attrelid=t.oid and a.attnum=d.refobjsubid
inner join cte on cte.sequencename = s.relname
where s.relkind='S' and n.nspname = 'public';
where s.relkind='S' and n.nspname = 'public'
"""
)
to_update = [i for i in cursor]
to_update = [i for i in cursor if (i[1], i[2]) not in foreign_key_exists]
# check 2: make sure the max(id) is less than the sequence value
cursor.execute(
"""
Expand All @@ -225,38 +245,38 @@ def run(cls, database=DEFAULT_DB_ALIAS, **kwargs):
and t.relname not like 'django%' and t.relname not like 'auth%';
"""
)
sequences = [i for i in cursor if i[0] not in [j[0] for j in to_update]]
sequences = [
i
for i in cursor
if i[0]
not in [
j[0] for j in to_update
] # make sure this sequence wasn't cpatured by check 1
and (i[1], i[2])
not in foreign_key_exists # exclude sequences that are foreign keys
]
for seq in sequences:
cursor.execute("select max(%s) from %s" % (seq[2], seq[1]))
max_id = cursor.fetchone()[0]
if max_id and max_id > (seq[3] or 0):
to_update.append((seq[0], seq[1], seq[2]))

for i in to_update:
sequencename = i[0]
tablename = i[1]
columnname = i[2]
cursor.execute(
"""
f"""
WITH numbered_rows AS (
SELECT %s, ROW_NUMBER() OVER (ORDER BY id) AS new_id
FROM %s
SELECT {columnname}, ROW_NUMBER() OVER (ORDER BY {columnname}) AS new_id
FROM {tablename}
)
UPDATE %s
SET %s = numbered_rows.new_id
UPDATE {tablename}
SET {columnname} = numbered_rows.new_id
FROM numbered_rows
WHERE %s.%s = numbered_rows.%s;
SELECT setval('%s', (SELECT max(%s) FROM %s));
WHERE {tablename}.{columnname} = numbered_rows.{columnname};
SELECT setval('{sequencename}', (SELECT max({columnname}) FROM {tablename}));
"""
% (
i[2],
i[1], # the first 2 are for the cte
i[1], # UPDATE %s
i[2], # SET %s
i[1], # WHERE %s.%s 1/2
i[2], # WHERE %s.%s 2/2
i[2], # numbered_rows.%s
i[0], # setval('%s'
i[2], # SELECT max(%s)
i[1], # FROM %s
)
)
logger.info("updated sequence %s for table %s" % (i[0], i[1]))

Expand Down

0 comments on commit 2f2b529

Please sign in to comment.