首页
社区
课程
招聘
[原创]第二题 CN星际基地
发表于: 2023-9-4 12:48 8068

[原创]第二题 CN星际基地

2023-9-4 12:48
8068

39个不同的4位3进制数顺序排列成4*39矩阵,要求每行和为0
代码由 ChatGPT 生成,用的是回溯算法加剪枝;我在此基础上手动优化,增加了一些限制条件:

用codon跑大概1小时出结果,python大概2.6小时

关键在于后面给出的提示:每行中 0、1、-1 三种符号各出现 13 次(这一点之前凭直觉也猜到了)。另外,“不能出现互为相反的元素”这条剪枝效果很大。理论上还可以继续优化:不用全局变量,改成传参;每次选定一个元素时,把后面与它相反的元素移除,并更新剩余 0、1、-1 的数量。

不带提示的情况,两小时后还在123456的序列,或许可以用多线程分布式计算,但是复杂度太高也优化不了多少。

# Shared lookup tables for the search. They start as None and are filled in by solve().
all_precomputed_base_3 = None  # indexed by sequence value -> that value's column of 4 digits

available_sequences = None  # indexed by seq_id -> sequence value

all_precomputed_remain_zero = None  # [row][seq_id] -> number of 0 digits in that row among the sequences after seq_id

all_precomputed_remain_neg = None  # [row][seq_id] -> number of -1 digits in that row among the sequences after seq_id

all_precomputed_remain_pos = None  # [row][seq_id] -> number of 1 digits in that row among the sequences after seq_id
 
def to_base_3(num):
    """Return the base-3 digits of *num*, most significant digit first.

    No padding is applied, so 0 yields an empty list.
    """
    digits = []
    value = num
    while value != 0:
        # Prepending keeps the most significant digit at the front.
        digits.insert(0, value % 3)
        value = value // 3
    return digits
 
def is_valid(matrix, col, cur_seq_id):
    """Decide whether the partial matrix can still lead to a solution.

    Columns 0..col of *matrix* are considered filled; column *col* holds the
    digits of the sequence at index *cur_seq_id* of ``available_sequences``.

    Returns False as soon as one pruning rule fires:
      1. fewer sequences remain after *cur_seq_id* than columns left to fill;
      2. the new column is the element-wise negation of an earlier column;
      3. a row's partial sum is larger in magnitude than the columns left;
      4. a row's partial sum cannot be cancelled by the -1 (or 1) digits
         still available in that row;
      5. a row can no longer reach 13 occurrences of 0, 1 or -1;
      6. a row already holds more than 13 of one symbol.
    Returns True otherwise.
    """
    cols_left = len(matrix[0]) - (col + 1)
    seqs_left = len(available_sequences) - (cur_seq_id + 1)
    if seqs_left < cols_left:
        return False

    # Two columns are "opposite" when every one of the 4 row pairs sums to 0.
    for earlier in range(col):
        if all(matrix[r][earlier] + matrix[r][col] == 0 for r in range(4)):
            return False

    for row in range(4):
        placed = matrix[row][:col + 1]
        total = sum(placed)
        if abs(total) > cols_left:
            return False

        zeros_ahead = all_precomputed_remain_zero[row][cur_seq_id]
        ones_ahead = all_precomputed_remain_pos[row][cur_seq_id]
        negs_ahead = all_precomputed_remain_neg[row][cur_seq_id]

        # A positive sum needs enough -1 digits ahead; a negative one needs enough 1s.
        if total > 0 and total > negs_ahead:
            return False
        if total < 0 and -total > ones_ahead:
            return False

        zeros_used = placed.count(0)
        ones_used = placed.count(1)
        negs_used = placed.count(-1)

        # Each symbol must still be able to reach 13 occurrences in this row...
        for used, ahead in ((zeros_used, zeros_ahead), (ones_used, ones_ahead), (negs_used, negs_ahead)):
            if used + ahead < 13:
                return False
        # ...and must not have gone past 13 already.
        if max(zeros_used, ones_used, negs_used) > 13:
            return False

    return True
 
def print_matrix(matrix):
    """Print each row of *matrix* on its own line, followed by one blank line."""
    # Every row carries its own newline; print() then adds the trailing blank line.
    print("".join(str(row) + "\n" for row in matrix))
 
def all_base_3_sequences():
    """Build the digit table for every value in 0..80.

    Entry ``i`` is the base-3 representation of ``i`` as a list of 4 digits,
    left-padded with zeros, with digit 2 rewritten as -1.
    """
    table = []
    for value in range(81):
        balanced = [d if d != 2 else -1 for d in to_base_3(value)]
        # to_base_3 drops leading zeros; restore them to reach 4 digits.
        padding = [0] * (4 - len(balanced))
        table.append(padding + balanced)
    return table
 
@python
def check_md5(s):
    """Return the hexadecimal MD5 digest of *s* encoded as UTF-8."""
    # NOTE(review): @python is presumably Codon's Python-interop decorator;
    # the import stays inside the function body for that reason - confirm.
    import hashlib
    return hashlib.md5(s.encode("utf-8")).hexdigest()
 
@python
def py_exit():
    """Wait for a line of input, then terminate the process with status 0."""
    # Presumably the input() pause keeps the printed result on screen until
    # the user presses Enter - confirm.
    input()
    exit(0)
 
def check_for_optimal(matrix):
    """Report *matrix* and exit if its flattened form has the target MD5.

    The matrix is flattened row by row into one string of digits, with every
    "-1" written as "2". If the MD5 of that string equals the hard-coded
    digest, the string, digest and matrix are printed and py_exit() is
    called; otherwise nothing happens.
    """
    flattened = "".join(str(cell) for row in matrix for cell in row)
    flattened = flattened.replace("-1", "2")

    digest = check_md5(flattened)
    if digest != "aac82b7ad77ab00dcef90ac079c9490d":
        return

    print(flattened)
    print("Optimal solution found:")
    print(flattened)
    print(digest)
    print_matrix(matrix)
    py_exit()
         
def backtracking(matrix, col, last_seq_id):
    """Fill *matrix* column by column with strictly increasing sequence ids.

    *col* is the next column to fill and *last_seq_id* the index (into
    ``available_sequences``) used for the previous column, or -1 for none.
    Columns are written in place and never cleared; is_valid() only reads
    columns 0..col, so stale values further right are not consulted there.
    Once every column is filled and all rows sum to 0, the matrix is handed
    to check_for_optimal().
    """
    if col >= len(matrix[0]):
        rows_sum_to_zero = all(sum(row) == 0 for row in matrix)
        if rows_sum_to_zero:
            check_for_optimal(matrix)
        return

    # Progress trace, limited to branches whose previous sequence id is small.
    show_progress = last_seq_id < 5
    first_candidate = last_seq_id + 1
    for cur_seq_id in range(first_candidate, len(available_sequences)):
        sequence = available_sequences[cur_seq_id]
        if show_progress:
            print(col, last_seq_id, cur_seq_id, sequence)
            print()
        digits = all_precomputed_base_3[sequence]

        for row_index in range(4):
            matrix[row_index][col] = digits[row_index]

        if not is_valid(matrix, col, cur_seq_id):
            continue
        backtracking(matrix, col + 1, cur_seq_id)
 
def _remaining_counts(digits, target):
    """For each index i, count how often *target* occurs in digits[i+1:]."""
    counts = [0] * len(digits)
    seen = 0
    # Walk from the right so each suffix count is available in a single pass.
    for idx in range(len(digits) - 1, -1, -1):
        counts[idx] = seen
        if digits[idx] == target:
            seen += 1
    return counts


def solve(from_seq):
    """Set up the global lookup tables and run the backtracking search.

    *from_seq* is the smallest sequence value (inclusive) to consider; values
    up to 80 are candidates, except 0, 40 and 80, whose four digits are all
    0, all 1 and all -1 respectively.

    The tables are printed before the search starts. Nothing is returned;
    results are reported by the search itself.
    """
    global available_sequences, all_precomputed_base_3, all_precomputed_remain_zero, all_precomputed_remain_pos, all_precomputed_remain_neg

    n_rows = 4
    n_cols = 39
    matrix = [[0 for _ in range(n_cols)] for _ in range(n_rows)]

    # Build the candidates in ascending order explicitly: iterating a set
    # gives no ordering guarantee, while the search and the suffix-count
    # tables below both depend on the position of each sequence in this list.
    excluded = (0, 40, 80)
    available_sequences = [value for value in range(from_seq, 81) if value not in excluded]
    all_precomputed_base_3 = all_base_3_sequences()

    # Per row, the digit each available sequence contributes to that row.
    all_matrix_digits = [
        [all_precomputed_base_3[sequence][row] for sequence in available_sequences]
        for row in range(n_rows)
    ]

    # Per row and seq_id: how many 0 / 1 / -1 digits the later sequences hold.
    all_precomputed_remain_zero = [_remaining_counts(row_digits, 0) for row_digits in all_matrix_digits]
    all_precomputed_remain_pos = [_remaining_counts(row_digits, 1) for row_digits in all_matrix_digits]
    all_precomputed_remain_neg = [_remaining_counts(row_digits, -1) for row_digits in all_matrix_digits]

    print(available_sequences)
    print()
    print(all_precomputed_base_3)
    print()
    print(all_precomputed_remain_zero)
    print()
    print(all_precomputed_remain_neg)
    print()
    print(all_precomputed_remain_pos)
    print()
    backtracking(matrix, 0, -1)
 
# Script entry point: run the search over sequence values starting at 1.
if __name__ == "__main__":
    solve(1)
# Shared lookup tables for the search. They start as None and are filled in by solve().
all_precomputed_base_3 = None  # indexed by sequence value -> that value's column of 4 digits

available_sequences = None  # indexed by seq_id -> sequence value

all_precomputed_remain_zero = None  # [row][seq_id] -> number of 0 digits in that row among the sequences after seq_id

all_precomputed_remain_neg = None  # [row][seq_id] -> number of -1 digits in that row among the sequences after seq_id

all_precomputed_remain_pos = None  # [row][seq_id] -> number of 1 digits in that row among the sequences after seq_id
 
def to_base_3(num):
    """Return the base-3 digits of *num*, most significant digit first.

    No padding is applied, so 0 yields an empty list.
    """
    digits = []
    value = num
    while value != 0:
        # Prepending keeps the most significant digit at the front.
        digits.insert(0, value % 3)
        value = value // 3
    return digits
 
def is_valid(matrix, col, cur_seq_id):
    """Decide whether the partial matrix can still lead to a solution.

    Columns 0..col of *matrix* are considered filled; column *col* holds the
    digits of the sequence at index *cur_seq_id* of ``available_sequences``.

    Returns False as soon as one pruning rule fires:
      1. fewer sequences remain after *cur_seq_id* than columns left to fill;
      2. the new column is the element-wise negation of an earlier column;
      3. a row's partial sum is larger in magnitude than the columns left;
      4. a row's partial sum cannot be cancelled by the -1 (or 1) digits
         still available in that row;
      5. a row can no longer reach 13 occurrences of 0, 1 or -1;
      6. a row already holds more than 13 of one symbol.
    Returns True otherwise.
    """
    cols_left = len(matrix[0]) - (col + 1)
    seqs_left = len(available_sequences) - (cur_seq_id + 1)
    if seqs_left < cols_left:
        return False

    # Two columns are "opposite" when every one of the 4 row pairs sums to 0.
    for earlier in range(col):
        if all(matrix[r][earlier] + matrix[r][col] == 0 for r in range(4)):
            return False

    for row in range(4):
        placed = matrix[row][:col + 1]
        total = sum(placed)
        if abs(total) > cols_left:
            return False

        zeros_ahead = all_precomputed_remain_zero[row][cur_seq_id]
        ones_ahead = all_precomputed_remain_pos[row][cur_seq_id]
        negs_ahead = all_precomputed_remain_neg[row][cur_seq_id]

        # A positive sum needs enough -1 digits ahead; a negative one needs enough 1s.
        if total > 0 and total > negs_ahead:
            return False
        if total < 0 and -total > ones_ahead:
            return False

        zeros_used = placed.count(0)
        ones_used = placed.count(1)
        negs_used = placed.count(-1)

        # Each symbol must still be able to reach 13 occurrences in this row...
        for used, ahead in ((zeros_used, zeros_ahead), (ones_used, ones_ahead), (negs_used, negs_ahead)):
            if used + ahead < 13:
                return False
        # ...and must not have gone past 13 already.
        if max(zeros_used, ones_used, negs_used) > 13:
            return False

    return True
 
def print_matrix(matrix):
    """Print each row of *matrix* on its own line, followed by one blank line."""
    # Every row carries its own newline; print() then adds the trailing blank line.
    print("".join(str(row) + "\n" for row in matrix))
 
def all_base_3_sequences():
    """Build the digit table for every value in 0..80.

    Entry ``i`` is the base-3 representation of ``i`` as a list of 4 digits,
    left-padded with zeros, with digit 2 rewritten as -1.
    """
    table = []
    for value in range(81):
        balanced = [d if d != 2 else -1 for d in to_base_3(value)]
        # to_base_3 drops leading zeros; restore them to reach 4 digits.
        padding = [0] * (4 - len(balanced))
        table.append(padding + balanced)
    return table
 
@python
def check_md5(s):

[招生]科锐逆向工程师培训(2024年11月15日实地,远程教学同时开班, 第51期)

最后于 2023-9-4 13:13 被kanxue编辑 ,原因:
收藏
免费 1
支持
分享
最新回复 (0)
游客
登录 | 注册 方可回帖
返回
//