首页
社区
课程
招聘
[原创]第二题 CN星际基地
2023-9-4 12:48 7186

[原创]第二题 CN星际基地

2023-9-4 12:48
7186

将39个不同的4位三进制数按顺序排列成4×39的矩阵(每个数占一列,各位取值为0、1、-1),要求每行元素之和为0。
代码由ChatGPT生成,据说用的是回溯算法加剪枝算法;之后手动优化,增加了一些限制条件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Module-level lookup tables. They are all None until solve() fills them in,
# and are then read by is_valid() and backtracking().

# Indexed by sequence value (0..80): that value's four digits, one per matrix
# row (most significant digit in row 0), with base-3 digit 2 stored as -1.
all_precomputed_base_3 = None
 
# Indexed by seq_id: the candidate sequence values, in the order the search
# tries them as matrix columns.
available_sequences = None
 
# [row][seq_id] -> number of 0 digits in that row among the sequences after seq_id.
all_precomputed_remain_zero = None
 
# [row][seq_id] -> number of -1 digits in that row among the sequences after seq_id.
all_precomputed_remain_neg = None
 
# [row][seq_id] -> number of 1 digits in that row among the sequences after seq_id.
all_precomputed_remain_pos = None
 
def to_base_3(num):
    """Return the base-3 digits of a non-negative integer, most significant first.

    Zero yields an empty list (no digits); callers pad to a fixed width.
    """
    digits = []
    remaining = num
    while remaining:
        # Prepend so the most significant digit ends up first.
        digits.insert(0, remaining % 3)
        remaining //= 3
    return digits
 
def is_valid(matrix, col, cur_seq_id):
    """Return True if the partial matrix (columns 0..col) may still be completed.

    Column ``col`` has just been filled with the sequence at index
    ``cur_seq_id`` of ``available_sequences``. The check fails when any of
    the following holds:

    * too few later sequences remain to fill the remaining columns;
    * the new column is the exact negation of an earlier column;
    * some row sum can no longer reach 0;
    * some row can no longer end with exactly 13 zeros, 13 ones and 13 -1s.

    Reads the module-level tables built by ``solve()``.
    """
    cols_left = len(matrix[0]) - col - 1
    seqs_left = len(available_sequences) - cur_seq_id - 1
    # Sequences are used in increasing index order, so only later ones remain.
    if seqs_left < cols_left:
        return False

    # A column whose four entries all cancel an earlier column's is that
    # column's negation; such pairs are not allowed.
    for prev in range(col):
        cancelling = sum(1 for r in range(4) if matrix[r][prev] + matrix[r][col] == 0)
        if cancelling >= 4:
            return False

    for row in range(4):
        prefix = matrix[row][:col + 1]
        row_sum = sum(prefix)
        # Each remaining column can move the row sum by at most 1.
        if abs(row_sum) > cols_left:
            return False

        zeros_after = all_precomputed_remain_zero[row][cur_seq_id]
        ones_after = all_precomputed_remain_pos[row][cur_seq_id]
        negs_after = all_precomputed_remain_neg[row][cur_seq_id]

        # Not enough opposite-signed digits remain in later sequences to
        # bring the row sum back to 0.
        if row_sum > 0 and row_sum - negs_after > 0:
            return False
        if row_sum < 0 and row_sum + ones_after < 0:
            return False

        n_zero = prefix.count(0)
        n_one = prefix.count(1)
        n_neg = prefix.count(-1)
        # Every digit must still be able to reach 13 occurrences in this row...
        if n_zero + zeros_after < 13 or n_one + ones_after < 13 or n_neg + negs_after < 13:
            return False
        # ...and none may exceed 13.
        if n_zero > 13 or n_one > 13 or n_neg > 13:
            return False

    return True
 
def print_matrix(matrix):
    """Print each row of ``matrix`` on its own line, followed by a blank line."""
    rendered = [str(r) for r in matrix]
    if rendered:
        print("\n".join(rendered))
    print()
 
def all_base_3_sequences():
    """Return a table mapping every value 0..80 to four digits in {-1, 0, 1}.

    Entry ``i`` holds the base-3 digits of ``i`` (most significant first),
    with each digit 2 replaced by -1 and the list left-padded with zeros to
    length 4.
    """
    table = []
    for value in range(81):
        mapped = [-1 if d == 2 else d for d in to_base_3(value)]
        padding = [0] * (4 - len(mapped))
        table.append(padding + mapped)
    return table
 
@python
def check_md5(s):
    """Return the hexadecimal MD5 digest of ``s`` encoded as UTF-8.

    ``@python`` is Codon's decorator for running a function in the CPython
    interpreter (this script is run under Codon); it is not defined in plain
    CPython.
    """
    import hashlib
    return hashlib.md5(s.encode("utf-8")).hexdigest()
 
@python
def py_exit():
    """Wait for the user to press Enter, then terminate with exit status 0.

    Uses ``sys.exit`` instead of the ``exit`` builtin: ``exit`` is only
    injected by the ``site`` module, is meant for the interactive shell, and
    is missing when ``site`` is not loaded (e.g. ``python -S`` or some
    embedded interpreters).
    """
    import sys
    input()
    sys.exit(0)
 
def check_for_optimal(matrix, target_hash="aac82b7ad77ab00dcef90ac079c9490d"):
    """Report ``matrix`` and stop the program if it matches the expected answer.

    The matrix is serialised row by row into a single digit string, with -1
    written as ``2`` so the string uses the digits 0/1/2. Its MD5 hex digest
    is compared with ``target_hash``. On a match, a header, the digit string,
    the digest and the matrix are each printed once, and ``py_exit()`` is
    called.

    Args:
        matrix: Rows of integers (0, 1 or -1) produced by the search.
        target_hash: Expected MD5 hex digest. The default is the value this
            script checks its answer against.
    """
    row_strings = ["".join(str(v) for v in row) for row in matrix]
    # With entries in {0, 1, -1}, only -1 stringifies with a "-", so this
    # replacement maps -1 -> 2 and nothing else.
    transformed_matrix_str = "".join(row_strings).replace("-1", "2")

    hash_value = check_md5(transformed_matrix_str)
    if hash_value != target_hash:
        return

    print("Optimal solution found:")
    print(transformed_matrix_str)
    print(hash_value)
    print_matrix(matrix)
    py_exit()
         
def backtracking(matrix, col, last_seq_id):
    """Depth-first search that fills ``matrix`` one column at a time.

    Column ``col`` receives the digits of a sequence whose index in
    ``available_sequences`` is greater than ``last_seq_id``. The chosen
    sequences therefore appear in increasing index order. Candidates rejected
    by ``is_valid`` are pruned. When every column is filled and all row sums
    are zero, the matrix is handed to ``check_for_optimal``.

    Args:
        matrix: Rows of the partially built matrix; modified in place.
        col: Index of the column to fill next.
        last_seq_id: Index of the sequence used for the previous column, or
            -1 at the start of the search.
    """
    if col >= len(matrix[0]):
        # All columns are filled; only a fully balanced matrix is a candidate.
        if all(sum(row) == 0 for row in matrix):
            check_for_optimal(matrix)
        return

    first_candidate = last_seq_id + 1
    for cur_seq_id in range(first_candidate, len(available_sequences)):
        sequence = available_sequences[cur_seq_id]
        # Progress trace, emitted only while the previous pick is among the
        # first few indices (the shallow part of the search).
        if last_seq_id < 5:
            print(col, last_seq_id, cur_seq_id, sequence)
            print()

        column_digits = all_precomputed_base_3[sequence]
        for r in range(4):
            matrix[r][col] = column_digits[r]

        if not is_valid(matrix, col, cur_seq_id):
            continue
        backtracking(matrix, col + 1, cur_seq_id)
 
def solve(from_seq):
    """Build the lookup tables used by the search, then run it.

    Args:
        from_seq: Smallest sequence value (at most 80) to use as a matrix
            column. The values 0, 40 and 80 are always excluded.
    """
    global available_sequences, all_precomputed_base_3, all_precomputed_remain_zero, all_precomputed_remain_pos, all_precomputed_remain_neg
    n_rows = 4
    n_cols = 39
    matrix = [[0 for _ in range(n_cols)] for _ in range(n_rows)]

    # Candidate values in ascending order. backtracking() picks columns in
    # increasing seq_id order, so this list's order fixes the column order of
    # every matrix it produces. Building it from a set relied on set iteration
    # order, which Python does not guarantee.
    excluded = (0, 40, 80)
    available_sequences = [v for v in range(from_seq, 81) if v not in excluded]
    all_precomputed_base_3 = all_base_3_sequences()

    # For each row and each seq_id, count how many 0, 1 and -1 digits that
    # row still has available among the sequences after seq_id.
    n_seq = len(available_sequences)
    all_matrix_digits = [[0 for _ in range(n_seq)] for _ in range(n_rows)]
    for seq_id in range(n_seq):
        sequence_digits = all_precomputed_base_3[available_sequences[seq_id]]
        for i in range(n_rows):
            all_matrix_digits[i][seq_id] = sequence_digits[i]
    all_precomputed_remain_zero = [[digits[c + 1:].count(0) for c in range(n_seq)] for digits in all_matrix_digits]
    all_precomputed_remain_pos = [[digits[c + 1:].count(1) for c in range(n_seq)] for digits in all_matrix_digits]
    all_precomputed_remain_neg = [[digits[c + 1:].count(-1) for c in range(n_seq)] for digits in all_matrix_digits]

    print(available_sequences)
    print()
    print(all_precomputed_base_3)
    print()
    print(all_precomputed_remain_zero)
    print()
    print(all_precomputed_remain_neg)
    print()
    print(all_precomputed_remain_pos)
    print()
    backtracking(matrix, 0, -1)
 
if __name__ == "__main__":
    # Search over sequence values from 1 upward; solve() additionally
    # excludes 0, 40 and 80.
    solve(1)

用Codon运行大约1小时出结果,用Python(CPython)运行大约需要2.6小时。

主要得益于后面给出的提示:每行中每种符号(0、1、-1)各出现13次,这一点凭直觉猜到了;另外,"不能出现互为相反数的元素"这一条剪枝的优化效果很大。理论上还可以进一步优化:不使用全局变量,改为参数传递;每次选定元素时,把后面与之相反的元素移除,并同步更新剩余0、1、-1的数量。

在不使用提示的情况下,运行两小时后仍停留在以123456开头的序列上。或许可以用多线程或分布式计算,但问题本身复杂度太高,这样也提升不了多少。


[培训]内核驱动高级班,冲击BAT一流互联网大厂工作,每周日13:00-18:00直播授课

最后于 2023-9-4 13:13 被kanxue编辑 ,原因:
收藏
点赞1
打赏
分享
最新回复 (0)
游客
登录 | 注册 方可回帖
返回